import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Use CJK-capable fonts so Chinese titles/labels in matplotlib figures render,
# and draw minus signs as ASCII '-' (these fonts may lack the Unicode minus glyph).
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageDraw
import cv2
import os
import csv
import importlib


def _optional_import(module_name, display_name):
    """Import an optional dependency, returning None (after a warning) if unusable.

    Besides ImportError (package not installed), any other exception raised while
    importing is also tolerated: a package that is installed but whose native
    library cannot be loaded (typically reported as OSError, e.g. a missing OpenMP
    runtime) should only disable that model, not crash the whole application.

    :param module_name: importable module name, e.g. "xgboost".
    :param display_name: human-readable name used in the warning message.
    :return: the imported module, or None.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        print(f"警告: 未安装{display_name}库,无法使用{display_name}模型")
    except Exception as exc:  # installed but broken (native library failed to load)
        print(f"警告: {display_name}库加载失败({exc}),无法使用{display_name}模型")
    return None


# Optional gradient-boosting libraries; a None module disables its model below.
xgb = _optional_import("xgboost", "XGBoost")
lgb = _optional_import("lightgbm", "LightGBM")

# model key -> (display name, estimator class or None when its library is
# unavailable, scaler class to apply before the estimator or None).
MODEL_METADATA = {
    'svm': ('支持向量机(SVM)', SVC, StandardScaler),
    'dt': ('决策树(DT)', DecisionTreeClassifier, None),
    'rf': ('随机森林(RF)', RandomForestClassifier, None),
    'xgb': ('XGBoost(XGB)', xgb.XGBClassifier if xgb else None, None),
    'lgb': ('LightGBM(LGB)', lgb.LGBMClassifier if lgb else None, None),
    'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler),
    'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler),
    'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None),
}
def get_split_data(digits_dataset):
    """Split a digits dataset into a fixed, reproducible 70/30 train/test partition.

    The split is seeded (``random_state=42``), so every caller that passes the same
    dataset receives the same partition.

    :param digits_dataset: object exposing ``data`` (features) and ``target`` (labels),
        such as the Bunch returned by ``load_digits()``.
    :return: ``(X_train, X_test, y_train, y_test)`` as produced by ``train_test_split``.
    """
    features = digits_dataset.data
    labels = digits_dataset.target
    return train_test_split(features, labels, test_size=0.3, random_state=42)
class ModelFactory:
    """Create, train and evaluate the classifiers registered in MODEL_METADATA."""

    # Constructor keyword arguments per model key (built once instead of per call).
    MODEL_PARAMS = {
        # probability=True is what makes SVC expose predict_proba.
        'svm': {'probability': True, 'random_state': 42},
        'dt': {'random_state': 42},
        'rf': {'n_estimators': 100, 'random_state': 42},
        # multi:softprob makes the booster emit per-class probabilities, so
        # predict_proba returns real probabilities and predict takes their argmax.
        'xgb': {'objective': 'multi:softprob', 'random_state': 42},
        'lgb': {'objective': 'multiclass', 'random_state': 42, 'num_class': 10,
                'max_depth': 5, 'min_child_samples': 10, 'learning_rate': 0.1,
                'force_col_wise': True},
        'mlp': {'hidden_layer_sizes': (100, 50), 'max_iter': 1000, 'random_state': 42},
        'knn': {'n_neighbors': 5, 'weights': 'distance'},
        'nb': {},
    }
    # Used for any model key not listed in MODEL_PARAMS.
    DEFAULT_PARAMS = {'random_state': 42}

    @staticmethod
    def create_model(model_type):
        """
        Create an untrained model and its (optional) data scaler.

        :param model_type: key of MODEL_METADATA, e.g. 'svm'
        :return: (model, scaler) where scaler is None if the model needs no scaling
        :raises KeyError: unknown model_type
        :raises ImportError: the model's library is not installed
        """
        name, model_cls, scaler_cls = MODEL_METADATA[model_type]
        if not model_cls:
            raise ImportError(f"{name}模型依赖库未安装")
        params = ModelFactory.MODEL_PARAMS.get(model_type, ModelFactory.DEFAULT_PARAMS)
        # ** unpacking copies the dict, so the shared class constant is never mutated.
        model = model_cls(**params)
        scaler = scaler_cls() if scaler_cls else None
        return model, scaler

    @staticmethod
    def _prepare_features(X, scaler=None, model_type=None, model=None, fit=False):
        """Apply the optional scaler and, for LightGBM, wrap the array in a DataFrame.

        LightGBM receives string column names ``pixel0 .. pixelN`` when fitting and its
        fitted ``feature_name_`` afterwards, so training and prediction inputs always
        carry identical feature names (integer column labels would be stringified by
        LightGBM and then disagree with the training DataFrame).

        :param X: 2-D feature array
        :param scaler: scaler to apply, or None
        :param model_type: model key; only 'lgb' triggers DataFrame wrapping
        :param model: fitted model whose ``feature_name_`` is reused (prediction only)
        :param fit: True while training -> fit the scaler instead of only applying it
        :return: transformed features
        """
        if scaler:
            X = scaler.fit_transform(X) if fit else scaler.transform(X)
        if model_type == 'lgb' and isinstance(X, np.ndarray):
            columns = None if fit else getattr(model, 'feature_name_', None)
            if columns is None:
                columns = [f"pixel{i}" for i in range(X.shape[1])]
            X = pd.DataFrame(X, columns=list(columns))
        return X

    @staticmethod
    def train_model(model, X_train, y_train, scaler=None, model_type=None):
        """
        Fit the model (and scaler, if any) on the training data.

        :param model: untrained model
        :param X_train: training features
        :param y_train: training labels
        :param scaler: scaler fitted on X_train before training, or None
        :param model_type: model key (needed for LightGBM input handling)
        :return: the fitted model (same object)
        """
        X_train = ModelFactory._prepare_features(X_train, scaler, model_type, fit=True)
        model.fit(X_train, y_train)
        return model

    @staticmethod
    def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
        """
        Compute the accuracy of a fitted model on the test data.

        :param model: fitted model
        :param X_test: test features
        :param y_test: test labels
        :param scaler: scaler already fitted on the training data, or None
        :param model_type: model key (needed for LightGBM input handling)
        :return: accuracy in [0, 1]
        """
        X_test = ModelFactory._prepare_features(X_test, scaler, model_type, model=model)
        y_pred = model.predict(X_test)
        return accuracy_score(y_test, y_pred)

    @staticmethod
    def train_and_evaluate(model_type, X_train, y_train, X_test, y_test):
        """
        Create, train and evaluate one model.

        :param model_type: model key
        :param X_train: training features
        :param y_train: training labels
        :param X_test: test features
        :param y_test: test labels
        :return: (fitted model, fitted scaler or None, accuracy)
        :raises Exception: any creation/training error is logged and re-raised
        """
        try:
            model, scaler = ModelFactory.create_model(model_type)
            model = ModelFactory.train_model(model, X_train, y_train, scaler, model_type)
            accuracy = ModelFactory.evaluate_model(model, X_test, y_test, scaler, model_type)
            return model, scaler, accuracy
        except Exception as e:
            print(f"模型 {model_type} 训练/评估错误: {str(e)}")
            raise  # bare raise keeps the original traceback
def accuracy_sort_key(result):
    """Return a result row's accuracy as a float for sorting.

    Rows whose accuracy is not numeric ("N/A" or an error message) map to -1,
    so they sort after every real accuracy when sorting in descending order.
    """
    try:
        return float(result["准确率"])
    except (TypeError, ValueError):
        return -1


def evaluate_all_models(digits_dataset):
    """
    Train and evaluate every model in MODEL_METADATA on the same fixed split.

    :param digits_dataset: handwritten-digits dataset (e.g. from load_digits())
    :return: list of {"模型名称": name, "准确率": text} dicts, best accuracy first.
        The accuracy text is formatted as "0.xxxx", or is "N/A" when the model's
        library is missing, or "错误: ..." when training/evaluation failed.
    """
    print("\n=== 模型评估 ===")
    X_train, X_test, y_train, y_test = get_split_data(digits_dataset)
    results = []
    for model_type, (name, model_cls, _) in MODEL_METADATA.items():
        print(f"评估模型: {name} ({model_type})")
        if not model_cls:
            results.append({"模型名称": name, "准确率": "N/A"})
            continue
        try:
            _, _, accuracy = ModelFactory.train_and_evaluate(
                model_type, X_train, y_train, X_test, y_test
            )
        except Exception as e:  # one failing model must not abort the comparison
            results.append({"模型名称": name, "准确率": f"错误: {str(e)}"})
            continue
        results.append({"模型名称": name, "准确率": f"{accuracy:.4f}"})
    results.sort(key=accuracy_sort_key, reverse=True)
    print(pd.DataFrame(results))
    return results
class HandwritingBoard:
    """Tkinter GUI: draw a digit, classify it with a selectable model, compare models.

    Every stroke is mirrored into a PIL grayscale image (``self.image``); that image,
    not the Tk canvas, is what ``preprocess_image`` turns into the 8x8 / 0..16
    feature format of ``sklearn.datasets.load_digits``.
    """

    # Stroke width in pixels, used identically on screen and in the PIL image so
    # that what the user sees is what gets recognized.
    BRUSH_WIDTH = 16
    # Drawing-surface size used until Tk reports the real canvas size.
    DEFAULT_CANVAS_SIZE = (600, 600)

    def __init__(self, root, model_factory, digits):
        """Build the window, train the default model and wire up events.

        :param root: Tk root window
        :param model_factory: object exposing ``train_and_evaluate`` (the ModelFactory class)
        :param digits: dataset returned by ``load_digits()``
        """
        self.root = root
        self.root.title("手写数字识别系统 (含模型性能对比)")
        self.root.geometry("1000x600")
        self.model_factory = model_factory
        self.digits = digits
        # model_type -> (model, scaler, accuracy, model_type)
        self.model_cache = {}
        self.current_model = None
        self.scaler = None
        self.current_model_type = None
        # Cached evaluate_all_models() output, computed on first use.
        self.performance_results = None
        self.has_drawn = False
        self.last_x, self.last_y = 0, 0
        # (64 pixel values, label) pairs gathered via the "save sample" dialog or a CSV.
        # NOTE(review): these samples are only saved/loaded; nothing here trains on them.
        self.custom_data = []
        self.drawing = False
        self.data_dir = "custom_digits_data"
        os.makedirs(self.data_dir, exist_ok=True)
        self.canvas_width, self.canvas_height = self.DEFAULT_CANVAS_SIZE
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_obj = ImageDraw.Draw(self.image)
        self.create_widgets()
        self.init_default_model()
        self.canvas.bind("<Configure>", self.on_canvas_resize)

    # ------------------------------------------------------------------ layout
    @staticmethod
    def _add_buttons(parent, specs):
        """Pack one full-width button per (text, command) pair into *parent*."""
        for text, command in specs:
            tk.Button(parent, text=text, command=command, width=12, height=1,
                      font=("Arial", 10)).pack(fill=tk.X, pady=4)

    @staticmethod
    def _clear_tree(tree):
        """Remove every row from a ttk.Treeview."""
        for item in tree.get_children():
            tree.delete(item)

    def create_widgets(self):
        """Create and lay out all widgets."""
        label_font = ("Arial", 10)
        section_font = ("Arial", 10, "bold")

        # Top bar: model selector (only models whose library is available).
        top_frame = tk.Frame(self.root)
        top_frame.pack(fill=tk.X, padx=10, pady=5)
        tk.Label(top_frame, text="选择模型:", font=label_font).pack(side=tk.LEFT, padx=5)
        self.available_models = [
            (key, name) for key, (name, model_cls, _) in MODEL_METADATA.items() if model_cls
        ]
        self.model_combobox = ttk.Combobox(
            top_frame,
            values=[name for _, name in self.available_models],
            state="readonly",
            width=15,
            font=label_font,
        )
        self.model_combobox.current(0)
        self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select)
        self.model_combobox.pack(side=tk.LEFT, padx=5)

        middle_frame = tk.Frame(self.root)
        middle_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5)

        # Left: drawing canvas.
        canvas_frame = tk.Frame(middle_frame)
        canvas_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 10))
        self.canvas = tk.Canvas(canvas_frame, bg="white")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.bind("<Button-1>", self.start_draw)
        self.canvas.bind("<B1-Motion>", self.draw)
        self.canvas.bind("<ButtonRelease-1>", self.stop_draw)
        self._draw_hint()

        # Right: control panel, grid with two equally weighted columns.
        control_frame = tk.Frame(middle_frame)
        control_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        control_frame.grid_columnconfigure(0, weight=1)
        control_frame.grid_columnconfigure(1, weight=1)

        current_model_frame = tk.LabelFrame(control_frame, text="当前模型", font=section_font)
        current_model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", pady=(0, 8), padx=3)
        self.model_label = tk.Label(current_model_frame, text="支持向量机(SVM)",
                                    font=("Arial", 12), relief=tk.RAISED, padx=8)
        self.model_label.pack(fill=tk.X, pady=5)

        button_frame = tk.LabelFrame(control_frame, text="操作", font=section_font)
        button_frame.grid(row=1, column=0, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self._add_buttons(button_frame, [
            ("识别", self.recognize),
            ("清除", self.clear_canvas),
            ("样本", self.show_samples),
            ("对比图表", self.show_performance_chart),
        ])

        train_set_frame = tk.LabelFrame(control_frame, text="训练集管理", font=section_font)
        train_set_frame.grid(row=1, column=1, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self._add_buttons(train_set_frame, [
            ("保存为训练样本", self.save_as_training_sample),
            ("保存全部训练集", self.save_all_training_data),
            ("加载训练集", self.load_training_data),
        ])

        result_frame = tk.LabelFrame(control_frame, text="识别结果", font=section_font)
        result_frame.grid(row=2, column=0, columnspan=2, sticky="ew", pady=(0, 8), padx=3)
        self.result_label = tk.Label(result_frame, text="请绘制数字", font=("Arial", 24))
        self.result_label.pack(pady=5)
        self.prob_label = tk.Label(result_frame, text="", font=label_font)
        self.prob_label.pack(pady=3)
        self.debug_label = tk.Label(result_frame, text="", font=("Arial", 9), wraplength=250)
        self.debug_label.pack(pady=3)

        self.confidence_frame = tk.LabelFrame(control_frame, text="识别置信度", font=section_font)
        self.confidence_frame.grid(row=3, column=0, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self.confidence_canvas = tk.Canvas(self.confidence_frame, bg="white", height=80)
        self.confidence_canvas.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)

        self.candidates_frame = tk.LabelFrame(control_frame, text="可能的数字", font=section_font)
        self.candidates_frame.grid(row=4, column=0, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self.candidates_tree = ttk.Treeview(self.candidates_frame, columns=("数字", "概率"),
                                            show="headings")
        for column in ("数字", "概率"):
            self.candidates_tree.column(column, width=70, anchor=tk.CENTER)
            self.candidates_tree.heading(column, text=column)
        self.candidates_tree.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)

        # Model comparison table, beside the confidence bar and candidate list.
        self.performance_frame = tk.LabelFrame(control_frame, text="模型性能对比", font=section_font)
        self.performance_frame.grid(row=3, column=1, rowspan=2, sticky="nsew", pady=(0, 8),
                                    padx=(3, 3))
        self.create_performance_table()

    # ------------------------------------------------------ model comparison
    @staticmethod
    def _parse_accuracy(text):
        """Return a results-row accuracy as float, or None for 'N/A'/error text."""
        try:
            return float(text)
        except (TypeError, ValueError):
            return None

    def _get_performance_results(self, refresh=False):
        """Return evaluate_all_models() results, computing them at most once.

        Every model trains on the same fixed split with fixed seeds, so re-running
        the (slow) evaluation on every model switch or chart request is wasted work.

        :param refresh: force a re-evaluation
        """
        if refresh or self.performance_results is None:
            self.performance_results = evaluate_all_models(self.digits)
        return self.performance_results

    def create_performance_table(self):
        """(Re)create the model-performance table with a vertical scrollbar."""
        for widget in self.performance_frame.winfo_children():
            widget.destroy()
        self.performance_tree = ttk.Treeview(self.performance_frame, columns=("模型名称", "准确率"),
                                             show="headings")
        self.performance_tree.column("模型名称", width=120, anchor=tk.W)
        self.performance_tree.column("准确率", width=80, anchor=tk.CENTER)
        self.performance_tree.heading("模型名称", text="模型名称")
        self.performance_tree.heading("准确率", text="准确率")
        scrollbar = ttk.Scrollbar(self.performance_frame, orient=tk.VERTICAL,
                                  command=self.performance_tree.yview)
        self.performance_tree.configure(yscrollcommand=scrollbar.set)
        self.performance_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.load_performance_data()

    def load_performance_data(self):
        """Fill the table; the first row is highlighted when it has a numeric accuracy."""
        results = self._get_performance_results()
        self._clear_tree(self.performance_tree)
        for index, result in enumerate(results):
            accuracy_text = result["准确率"]
            is_best = index == 0 and self._parse_accuracy(accuracy_text) is not None
            self.performance_tree.insert("", tk.END, values=(result["模型名称"], accuracy_text),
                                         tags=("highlight",) if is_best else ())
        self.performance_tree.tag_configure("highlight", background="#e6f7ff",
                                            font=("Arial", 9, "bold"))

    def show_performance_chart(self):
        """Show a horizontal bar chart of every model with a numeric accuracy."""
        valid_results = []
        for result in self._get_performance_results():
            accuracy = self._parse_accuracy(result["准确率"])
            if accuracy is not None:
                valid_results.append((result["模型名称"], accuracy))
        if not valid_results:
            messagebox.showinfo("提示", "没有可用的模型性能数据来生成图表")
            return
        valid_results.sort(key=lambda item: item[1], reverse=True)
        plt.figure(figsize=(10, 6))
        models, accuracies = zip(*valid_results)
        bars = plt.barh(models, accuracies, color='skyblue')
        plt.xlabel('准确率', fontsize=10)
        plt.ylabel('模型', fontsize=10)
        plt.title('各模型在手写数字识别上的性能对比', fontsize=12)
        plt.xlim(0, 1.05)
        for bar in bars:
            width = bar.get_width()
            plt.text(width + 0.01, bar.get_y() + bar.get_height() / 2,
                     f'{width:.4f}', ha='left', va='center', fontsize=8)
        plt.tight_layout()
        plt.show()
        plt.close()

    # ---------------------------------------------------------------- drawing
    def _draw_hint(self):
        """Draw the centred placeholder text (tagged "hint" so it can be removed)."""
        self.canvas.create_text(self.canvas_width / 2, self.canvas_height / 2, text="绘制数字",
                                fill="gray", font=("Arial", 16), tags=("hint",))

    def _reset_drawing_surface(self):
        """Discard the drawing: new blank image at the current size, cleared canvas, hint."""
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_obj = ImageDraw.Draw(self.image)
        self.canvas.delete("all")
        self._draw_hint()
        self.has_drawn = False

    def start_draw(self, event):
        """Begin a stroke at the mouse position and remove the placeholder text."""
        self.drawing = True
        self.last_x, self.last_y = event.x, event.y
        self.canvas.delete("hint")

    def draw(self, event):
        """Extend the current stroke to the mouse position on canvas and image."""
        if not self.drawing:
            return
        x, y = event.x, event.y
        radius = self.BRUSH_WIDTH // 2
        # Canvas: a round-capped segment keeps fast strokes continuous on screen.
        self.canvas.create_line(self.last_x, self.last_y, x, y, width=self.BRUSH_WIDTH,
                                fill="black", capstyle=tk.ROUND, joinstyle=tk.ROUND)
        # Image: same segment, plus a disc at the joint so turns leave no gaps.
        self.draw_obj.line([self.last_x, self.last_y, x, y], fill=0, width=self.BRUSH_WIDTH)
        self.draw_obj.ellipse([x - radius, y - radius, x + radius, y + radius], fill=0)
        self.last_x, self.last_y = x, y

    def stop_draw(self, event):
        """End the current stroke."""
        self.drawing = False
        self.has_drawn = True

    def clear_canvas(self):
        """Clear the drawing and every recognition result."""
        width, height = self.canvas.winfo_width(), self.canvas.winfo_height()
        # Before the window is mapped Tk reports 1x1; keep the last known size then.
        if width > 1 and height > 1:
            self.canvas_width, self.canvas_height = width, height
        self._reset_drawing_surface()
        self.result_label.config(text="请绘制数字")
        self.prob_label.config(text="")
        self.debug_label.config(text="")
        self.clear_confidence_display()

    def clear_confidence_display(self):
        """Clear the confidence bar and the candidate list."""
        self.confidence_canvas.delete("all")
        self._clear_tree(self.candidates_tree)

    def on_canvas_resize(self, event):
        """Recreate the drawing surface when the canvas actually changes size."""
        if event.width <= 1 or event.height <= 1:
            return
        # <Configure> is also delivered for non-size changes; keep the drawing then.
        if (event.width, event.height) == (self.canvas_width, self.canvas_height):
            return
        self.canvas_width, self.canvas_height = event.width, event.height
        self._reset_drawing_surface()

    # ------------------------------------------------------------ recognition
    @staticmethod
    def _scale_to_digit_range(ink):
        """Map ink intensities (0 = paper) onto integers 0..16 as used by load_digits.

        The strongest cell is stretched to 16, so thin strokes on a large canvas
        (which average out to faint cells after downsampling) still span the
        dataset's full intensity range. An all-zero input stays all-zero.
        """
        ink = np.asarray(ink, dtype=np.float64)
        peak = ink.max() if ink.size else 0.0
        if peak <= 0:
            return np.zeros(ink.shape, dtype=np.uint8)
        return np.rint(ink / peak * 16).astype(np.uint8)

    def preprocess_image(self):
        """Convert the drawing into a 64-value vector in the load_digits format.

        :return: flat array of 64 values in 0..16 (0 = background, 16 = ink), or
            64 zeros when nothing was drawn.
        """
        img_array = np.array(self.image)
        img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
        # Invert: ink becomes 255 and paper 0, which is what findContours expects.
        _, img_array = cv2.threshold(img_array, 127, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(img_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            self.debug_label.config(text="未检测到有效数字,请重新绘制")
            return np.zeros(64)
        # Bounding box of *all* strokes: digits such as 4, 5 or 7 are often drawn
        # with disconnected strokes that the largest contour alone would crop off.
        x, y, w, h = cv2.boundingRect(np.vstack(contours))
        digit = img_array[y:y + h, x:x + w]
        size = max(w, h)
        # Pad to a square with *background* (0 in the inverted image) to keep the
        # aspect ratio when shrinking to 8x8.
        padded = np.zeros((size, size), dtype=np.uint8)
        offset_x = (size - w) // 2
        offset_y = (size - h) // 2
        padded[offset_y:offset_y + h, offset_x:offset_x + w] = digit
        resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA)
        return self._scale_to_digit_range(resized).flatten()

    def recognize(self):
        """Classify the drawing and show label, confidence and top-3 candidates."""
        if not self.has_drawn:
            self.debug_label.config(text="请先绘制数字再识别")
            return
        if self.current_model_type is None:
            self.debug_label.config(text="模型类型未正确设置,请重新加载模型")
            return
        if self.current_model is None:
            self.debug_label.config(text="模型未加载,请选择并加载模型")
            return
        img = self.preprocess_image()
        if img.sum() == 0:
            self.clear_confidence_display()
            return
        model = self.current_model
        img_input = img.reshape(1, -1)
        try:
            if self.scaler:
                img_input = self.scaler.transform(img_input)
            # Reuse the training-time feature names so LightGBM sees identical columns.
            if self.current_model_type == 'lgb' and hasattr(model, 'feature_name_'):
                img_input = pd.DataFrame(img_input, columns=model.feature_name_)
            pred = model.predict(img_input)[0]
            self.result_label.config(text=f"识别结果: {pred}")
            if not hasattr(model, 'predict_proba'):
                self.prob_label.config(text="置信度: 该模型不支持概率输出")
                self.debug_label.config(text="")
                self.clear_confidence_display()
                return
            probs = model.predict_proba(img_input)[0]
            # predict_proba columns follow model.classes_; look labels up there
            # instead of assuming label == column index.
            classes = list(getattr(model, 'classes_', range(len(probs))))
            confidence = probs[classes.index(pred)]
            self.prob_label.config(text=f"置信度: {confidence:.2%}")
            self.update_confidence_display(confidence)
            top3 = sorted(zip(classes, probs), key=lambda item: item[1], reverse=True)[:3]
            self.update_candidates_display(top3)
            self.debug_label.config(text="\n".join(f"数字 {d}: 概率 {p:.2%}" for d, p in top3))
        except Exception as e:  # GUI boundary: report instead of crashing the event loop
            self.debug_label.config(text=f"识别错误: {str(e)}")
            self.clear_confidence_display()
            print("识别异常:", e)

    def update_confidence_display(self, confidence):
        """Draw the confidence bar (colour-coded) with a 0-100 scale.

        :param confidence: probability in [0, 1]
        """
        canvas = self.confidence_canvas
        canvas.delete("all")
        width = canvas.winfo_width()
        if width <= 1:  # Tk reports 1 until the widget has been laid out
            width = 300
        canvas.create_rectangle(15, 15, width - 15, 45, fill="#f0f0f0", outline="gray")
        bar_width = int((width - 30) * confidence)
        canvas.create_rectangle(15, 15, 15 + bar_width, 45,
                                fill=self.get_confidence_color(confidence), outline="")
        canvas.create_text(width / 2, 30, text=f"置信度: {confidence:.2%}", font=("Arial", 9))
        # Tick every 10 %, label every 20 %.
        for i in range(11):
            x_pos = 15 + i * (width - 30) / 10
            canvas.create_line(x_pos, 45, x_pos, 50, width=1)
            if i % 2 == 0:
                canvas.create_text(x_pos, 60, text=f"{i * 10}", font=("Arial", 7))

    def get_confidence_color(self, confidence):
        """Return green (>= 0.9), yellow (>= 0.7) or red for a confidence value."""
        if confidence >= 0.9:
            return "#2ecc71"
        if confidence >= 0.7:
            return "#f1c40f"
        return "#e74c3c"

    def update_candidates_display(self, candidates):
        """Show (digit, probability) pairs in the candidate list."""
        self._clear_tree(self.candidates_tree)
        for digit, prob in candidates:
            self.candidates_tree.insert("", tk.END, values=(digit, f"{prob:.2%}"))

    def show_samples(self):
        """Show the first dataset sample of each digit 0-9."""
        plt.figure(figsize=(10, 5))
        for i in range(10):
            plt.subplot(2, 5, i + 1)
            sample_idx = np.where(self.digits.target == i)[0][0]
            plt.imshow(self.digits.images[sample_idx], cmap="gray")
            plt.title(f"数字 {i}", fontsize=10)
            plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()

    # ---------------------------------------------------------- model switching
    def _select_model_in_combobox(self, model_type):
        """Select the combobox entry for *model_type*, if it is listed."""
        for index, (key, _) in enumerate(self.available_models):
            if key == model_type:
                self.model_combobox.current(index)
                return

    def on_model_select(self, event):
        """Switch to the model chosen in the combobox."""
        index = self.model_combobox.current()
        if index < 0:
            return
        self.change_model(self.available_models[index][0])

    def change_model(self, model_type):
        """Activate *model_type*, training it on first use and caching the result."""
        print(f"触发 change_model,选中模型键: {model_type}")
        model_name = MODEL_METADATA.get(model_type, (model_type,))[0]
        cached = self.model_cache.get(model_type)
        if cached is not None:
            self.current_model, self.scaler, accuracy, self.current_model_type = cached
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.debug_label.config(text=f"已从缓存加载 {model_name}")
            return
        print(f"\n=== 开始加载 {model_name} 模型 ===")
        X_train, X_test, y_train, y_test = get_split_data(self.digits)
        try:
            model, scaler, accuracy = self.model_factory.train_and_evaluate(
                model_type, X_train, y_train, X_test, y_test
            )
        except Exception as e:  # GUI boundary: keep the previous model usable
            self.debug_label.config(text=f"模型加载失败: {str(e)}")
            print(f"加载异常: {e}\n")
            # Keep the selector consistent with the model that is still active.
            self._select_model_in_combobox(self.current_model_type)
            return
        self.current_model, self.scaler, self.current_model_type = model, scaler, model_type
        self.model_cache[model_type] = (model, scaler, accuracy, model_type)
        self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
        # clear_canvas() resets debug_label, so the status message must come after it.
        self.clear_canvas()
        self.debug_label.config(text=f"模型加载完成,测试集准确率: {accuracy:.4f}")
        print(f"=== {model_name} 加载完成,准确率 {accuracy:.4f} ===\n")
        self.load_performance_data()

    def init_default_model(self):
        """Load the default (SVM) model."""
        self.change_model('svm')

    # ------------------------------------------------------- training samples
    def save_as_training_sample(self):
        """Ask for a label and store the current drawing in ``custom_data``."""
        if not self.has_drawn:
            self.debug_label.config(text="请先绘制数字再保存")
            return
        img = self.preprocess_image()
        if img.sum() == 0:
            self.debug_label.config(text="未检测到有效数字,无法保存")
            return
        label_window = tk.Toplevel(self.root)
        label_window.title("输入数字标签")
        label_window.geometry("300x150")
        label_window.transient(self.root)
        tk.Label(label_window, text="请输入您绘制的数字 (0-9):",
                 font=("Arial", 10)).pack(pady=10)
        entry = tk.Entry(label_window, font=("Arial", 12), width=8)
        entry.pack(pady=5)
        entry.focus_set()

        def save_with_label():
            try:
                label = int(entry.get())
                if not 0 <= label <= 9:
                    raise ValueError("标签必须是0到9之间的数字")
            except ValueError as e:
                self.debug_label.config(text=f"输入错误: {str(e)}")
                return
            self.custom_data.append((img.tolist(), label))
            self.debug_label.config(
                text=f"已保存数字 {label} 到训练集 (当前共有 {len(self.custom_data)} 个样本)")
            label_window.destroy()

        entry.bind("<Return>", lambda _event: save_with_label())
        tk.Button(label_window, text="保存", command=save_with_label,
                  width=10, height=1, font=("Arial", 10)).pack(pady=8)
        label_window.grab_set()

    def save_all_training_data(self):
        """Write ``custom_data`` to a CSV file (pixel0..pixel63, label)."""
        if not self.custom_data:
            self.debug_label.config(text="没有训练数据可保存")
            return
        file_path = filedialog.asksaveasfilename(
            defaultextension=".csv",
            filetypes=[("CSV文件", "*.csv"), ("所有文件", "*.*")],
            initialdir=self.data_dir,
            initialfile="custom_digits_training.csv",
            title="保存训练集"
        )
        if not file_path:
            return
        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
                writer.writerows(img_data + [label] for img_data, label in self.custom_data)
        except (OSError, csv.Error) as e:
            self.debug_label.config(text=f"保存失败: {str(e)}")
            print(f"保存训练集异常: {e}")
            return
        self.debug_label.config(text=f"已保存 {len(self.custom_data)} 个训练样本到 {file_path}")

    def load_training_data(self):
        """Replace ``custom_data`` with the samples read from a CSV file.

        The existing samples are only replaced after the whole file parsed
        successfully; rows with fewer than 65 columns are skipped.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("CSV文件", "*.csv"), ("所有文件", "*.*")],
            initialdir=self.data_dir,
            title="加载训练集"
        )
        if not file_path:
            return
        loaded = []
        try:
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the header; tolerate an empty file
                for row in reader:
                    if len(row) < 65:
                        continue
                    loaded.append(([int(pixel) for pixel in row[:64]], int(row[64])))
        except (OSError, ValueError, csv.Error) as e:
            self.debug_label.config(text=f"加载失败: {str(e)}")
            print(f"加载训练集异常: {e}")
            return
        self.custom_data = loaded
        self.debug_label.config(text=f"已从 {file_path} 加载 {len(self.custom_data)} 个训练样本")

    def run(self):
        """Run the Tk main loop."""
        self.root.mainloop()
if __name__ == "__main__":
    # Load the 8x8 digits dataset once and hand it, with the model factory, to the GUI.
    dataset = load_digits()
    main_window = tk.Tk()
    HandwritingBoard(main_window, ModelFactory, dataset).run()
# 帮我优化代码 (leftover note: "please help me optimize the code")