课程设计与要求:
实验18 手写数字识别程序设计与实现
实验类型:设计性实验
实验学时:8
涉及的知识点:SVM、决策树、随机森林、XGBoost和LightGBM机器学习算法的综合应用
一、 实验目的
1、 了解机器学习算法应用项目设计流程与基本方法。
2、 掌握SVM应用设计与K折交叉验证法获得测试数据。
3、 熟悉两种以上不同类型机器学习算法及应用。
4、 掌握各类机器学习算法的区别、优缺点;会应用网格搜索选择最优超参数。
5、 掌握分类任务的性能指标评价方法。
二、 实验要求
1、 使用anaconda集成开发环境完成课程设计,代码的可维护性好,有必要的注释和相应的文档。
2、 能够识别符合分辨率要求的手写数字。
3、 构建不同模型实现手写数字分类识别,至少要对比两种方法,如决策树、支持向量机、随机森林、XGBoost和LightGBM等。对比不同模型的分类性能报告,评价模型好坏。
4、 数据集采用sklearn.datasets中的digits,测试集数据可以用自己手写产生或者从digits中拆分。
三、 设计指标
1、 完整的设计文档
1) 系统的需求分析
2) 系统的概要设计
3) 详细设计与实现
4) 系统测试方法
2、 运行画面截图
3、 每一部分附上关键性代码
4、 项目总结
四、 预习与参考
1、 教材有关决策树、SVM(支持向量机)、随机森林、XGBoost和LightGBM有关章节。
2、 课程PPT有关内容。
3、 中国知网有关手写数字识别的文献资料。
五、 考核形式
根据提交的设计文档完成程度以及程序功能的实现情况(要求演示)进行考核:
无任何文档,无程序,得 0 分;
文档描述不清楚,思路混乱,程序不能运行,2分;
文档描述清晰,程序实现了基本功能,3.5分;
文档描述清晰准确,思路清晰,程序实现了要求的所有功能,4. 5分;
文档完备,设计合理有创新,报告清晰明确,深入分析了自己进行实验的体会感想,程序实现了全部功能,功能完善,并有其它的创新实现,5分。
六、 实验报告要求
1、 实验目的结合自己个人的实际情况书写,不要雷同。
2、 项目概要设计说明书(描述软件系统架构、逻辑架构、物理架构、部署结构、功能架构及关键技术,关键业务模块需通过UML图进行详细描述)、需求规格说明书(包括功能设计、非功能性设计、系统用例)。
3、 项目设计运行截图。
代码程序:
# ========================
# 导入必要的库
# ========================
import numpy as np # 数值计算库
import matplotlib.pyplot as plt # 绘图库
import pandas as pd # 数据处理库
import tkinter as tk # GUI库
from tkinter import ttk, filedialog, messagebox # GUI组件
from PIL import Image, ImageDraw # 图像处理库
import cv2 # 计算机视觉库
import os # 操作系统接口
import csv # CSV文件处理
from sklearn.datasets import load_digits # 加载数字数据集
from sklearn.model_selection import train_test_split # 数据集划分
from sklearn.svm import SVC # 支持向量机模型
from sklearn.tree import DecisionTreeClassifier # 决策树模型
from sklearn.ensemble import RandomForestClassifier # 随机森林模型
from sklearn.neural_network import MLPClassifier # 多层感知机模型
from sklearn.neighbors import KNeighborsClassifier # K近邻模型
from sklearn.naive_bayes import GaussianNB # 朴素贝叶斯模型
from sklearn.metrics import accuracy_score # 准确率评估
from sklearn.preprocessing import StandardScaler # 数据标准化
# 设置中文字体和负号显示(解决中文乱码问题)
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False
# ========================
# 尝试导入可选模型库
# ========================
XGB_INSTALLED = False # 标记XGBoost是否安装
LGB_INSTALLED = False # 标记LightGBM是否安装
try:
import xgboost as xgb # XGBoost模型
XGB_INSTALLED = True
except ImportError:
print("警告: 未安装XGBoost库,无法使用XGBoost模型")
try:
import lightgbm as lgb # LightGBM模型
LGB_INSTALLED = True
except ImportError:
print("警告: 未安装LightGBM库,无法使用LightGBM模型")
# ========================
# 模型配置
# ========================
# 定义模型元数据(包含模型名称、类、标准化器和参数)
MODEL_METADATA = {
'svm': ('支持向量机(SVM)', SVC, StandardScaler, {'probability': True, 'random_state': 42}),
'dt': ('决策树(DT)', DecisionTreeClassifier, None, {'random_state': 42}),
'rf': ('随机森林(RF)', RandomForestClassifier, None, {'n_estimators': 100, 'random_state': 42}),
'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler, {'hidden_layer_sizes': (100, 50), 'max_iter': 500, 'random_state': 42}),
'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler, {'n_neighbors': 5, 'weights': 'distance'}),
'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None, {}),
}
# 添加可选模型(如果已安装)
if XGB_INSTALLED:
MODEL_METADATA['xgb'] = ('XGBoost(XGB)', xgb.XGBClassifier, None, {'objective': 'multi:softmax', 'random_state': 42})
if LGB_INSTALLED:
MODEL_METADATA['lgb'] = ('LightGBM(LGB)', lgb.LGBMClassifier, None, {
'objective': 'multiclass',
'random_state': 42,
'num_class': 10,
'max_depth': 5,
'min_child_samples': 10,
'learning_rate': 0.1,
'force_col_wise': True
})
# ========================
# 模型工厂类 - 负责模型创建、训练和评估
# ========================
class ModelFactory:
@staticmethod
def get_split_data(digits_dataset):
"""数据集划分"""
X, y = digits_dataset.data, digits_dataset.target # 获取特征和标签
# 划分训练集和测试集(70%训练,30%测试)
return train_test_split(X, y, test_size=0.3, random_state=42)
@classmethod
def create_model(cls, model_type):
"""创建模型和数据标准化器"""
# 检查模型类型是否有效
if model_type not in MODEL_METADATA:
raise ValueError(f"未知模型类型: {model_type}")
# 从配置中获取模型信息
name, model_cls, scaler_cls, params = MODEL_METADATA[model_type]
# 创建模型实例和标准化器
model = model_cls(**params)
scaler = scaler_cls() if scaler_cls else None
return model, scaler
@staticmethod
def train_model(model, X_train, y_train, scaler=None, model_type=None):
"""训练模型"""
# 数据标准化处理
if scaler:
X_train = scaler.fit_transform(X_train)
# LightGBM特殊处理(需要DataFrame格式)
if model_type == 'lgb' and isinstance(X_train, np.ndarray):
X_train = pd.DataFrame(X_train)
# 训练模型
model.fit(X_train, y_train)
return model
@staticmethod
def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
"""评估模型"""
# 数据标准化处理
if scaler:
X_test = scaler.transform(X_test)
# LightGBM特殊处理
if model_type == 'lgb' and isinstance(X_test, np.ndarray) and hasattr(model, 'feature_name_'):
X_test = pd.DataFrame(X_test, columns=model.feature_name_)
# 预测并计算准确率
y_pred = model.predict(X_test)
return accuracy_score(y_test, y_pred)
@classmethod
def train_and_evaluate(cls, model_type, X_train, y_train, X_test, y_test):
"""训练并评估模型"""
try:
# 创建模型
model, scaler = cls.create_model(model_type)
# 训练模型
model = cls.train_model(model, X_train, y_train, scaler, model_type)
# 评估模型
accuracy = cls.evaluate_model(model, X_test, y_test, scaler, model_type)
return model, scaler, accuracy
except Exception as e:
print(f"模型 {model_type} 训练/评估错误: {str(e)}")
raise
@classmethod
def evaluate_all_models(cls, digits_dataset):
"""评估所有可用模型"""
print("\n=== 模型评估 ===")
# 划分数据集
X_train, X_test, y_train, y_test = cls.get_split_data(digits_dataset)
results = [] # 存储结果
# 遍历所有模型
for model_type in MODEL_METADATA:
name = MODEL_METADATA[model_type][0]
print(f"评估模型: {name} ({model_type})")
# 检查模型是否可用
if not MODEL_METADATA[model_type][1]:
results.append({"模型名称": name, "准确率": "N/A"})
continue
try:
# 训练并评估模型
_, _, accuracy = cls.train_and_evaluate(
model_type, X_train, y_train, X_test, y_test
)
results.append({"模型名称": name, "准确率": f"{accuracy:.4f}"})
except Exception as e:
results.append({"模型名称": name, "准确率": f"错误: {str(e)}"})
# 按准确率排序
results.sort(
key=lambda x: float(x["准确率"])
if isinstance(x["准确率"], str) and x["准确率"].replace('.', '', 1).isdigit()
else -1,
reverse=True
)
# 打印结果
print(pd.DataFrame(results))
return results
# ========================
# 手写板类 - GUI界面和绘图功能
# ========================
class HandwritingBoard:
CANVAS_SIZE = 300 # 固定画布尺寸
BRUSH_SIZE = 12 # 画笔大小
def __init__(self, root, model_factory, digits):
# 初始化主窗口
self.root = root
self.root.title("手写数字识别系统")
self.root.geometry("1000x700") # 设置窗口大小
# 模型和数据相关
self.model_factory = model_factory # 模型工厂
self.digits = digits # 数字数据集
self.model_cache = {} # 模型缓存(提高切换速度)
self.current_model = None # 当前使用的模型
self.scaler = None # 数据标准化器
self.current_model_type = None # 当前模型类型
self.has_drawn = False # 标记是否已绘制数字
self.custom_data = [] # 存储自定义训练数据
# 绘图相关状态
self.drawing = False # 是否正在绘制
self.last_x = self.last_y = 0 # 上次绘制位置
# 创建自定义数据目录
self.data_dir = "custom_digits_data"
os.makedirs(self.data_dir, exist_ok=True)
# 初始化画布(PIL图像)
self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255) # 创建白色背景图像
self.draw_obj = ImageDraw.Draw(self.image) # 创建绘图对象
# 创建界面组件
self.create_widgets()
# 初始化默认模型
self.init_default_model()
def create_widgets(self):
"""创建界面组件"""
# 创建主框架
main_frame = tk.Frame(self.root)
main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 1. 模型选择区域
model_frame = tk.LabelFrame(main_frame, text="模型选择", font=("Arial", 10, "bold"))
model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", padx=5, pady=5)
model_frame.grid_columnconfigure(1, weight=1) # 让模型标签可以扩展
# 模型选择标签
tk.Label(model_frame, text="选择模型:", font=("Arial", 10)).grid(row=0, column=0, padx=5, pady=5, sticky="w")
# 获取可用模型列表
self.available_models = []
for model_type, (name, _, _, _) in MODEL_METADATA.items():
if MODEL_METADATA[model_type][1]:
self.available_models.append((model_type, name))
# 模型选择下拉框
self.model_var = tk.StringVar()
self.model_combobox = ttk.Combobox(
model_frame,
textvariable=self.model_var,
values=[name for _, name in self.available_models],
state="readonly",
width=25,
font=("Arial", 10)
)
self.model_combobox.current(0) # 设置默认选项
self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select) # 绑定选择事件
self.model_combobox.grid(row=0, column=1, padx=5, pady=5, sticky="ew")
# 模型信息标签(显示准确率)
self.model_label = tk.Label(
model_frame,
text="",
font=("Arial", 10),
relief=tk.SUNKEN,
padx=5,
pady=2
)
self.model_label.grid(row=0, column=2, padx=5, pady=5, sticky="ew")
# 2. 左侧绘图区域和右侧结果区域
# 左侧绘图区域
left_frame = tk.LabelFrame(main_frame, text="绘制区域", font=("Arial", 10, "bold"))
left_frame.grid(row=1, column=0, padx=5, pady=5, sticky="nsew")
# 绘图画布
self.canvas = tk.Canvas(left_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE)
self.canvas.pack(padx=10, pady=10)
# 绑定绘图事件
self.canvas.bind("<Button-1>", self.start_draw) # 鼠标按下
self.canvas.bind("<B1-Motion>", self.draw) # 鼠标拖动
self.canvas.bind("<ButtonRelease-1>", self.stop_draw) # 鼠标释放
# 添加绘制提示
self.canvas.create_text(
self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
text="绘制数字", fill="gray", font=("Arial", 16)
)
# 绘图控制按钮
btn_frame = tk.Frame(left_frame)
btn_frame.pack(fill=tk.X, pady=(0, 10))
# 功能按钮
tk.Button(btn_frame, text="识别", command=self.recognize, width=8).pack(side=tk.LEFT, padx=5)
tk.Button(btn_frame, text="清除", command=self.clear_canvas, width=8).pack(side=tk.LEFT, padx=5)
tk.Button(btn_frame, text="样本", command=self.show_samples, width=8).pack(side=tk.LEFT, padx=5)
# 右侧结果区域
right_frame = tk.Frame(main_frame)
right_frame.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")
# 2.1 识别结果区域
result_frame = tk.LabelFrame(right_frame, text="识别结果", font=("Arial", 10, "bold"))
result_frame.pack(fill=tk.X, padx=5, pady=5)
# 结果显示标签
self.result_label = tk.Label(
result_frame,
text="请绘制数字",
font=("Arial", 24),
pady=10
)
self.result_label.pack()
# 置信度显示标签
self.prob_label = tk.Label(
result_frame,
text="",
font=("Arial", 12)
)
self.prob_label.pack()
# 2.2 置信度可视化区域
confidence_frame = tk.LabelFrame(right_frame, text="识别置信度", font=("Arial", 10, "bold"))
confidence_frame.pack(fill=tk.X, padx=5, pady=5)
# 置信度画布(条形图)
self.confidence_canvas = tk.Canvas(
confidence_frame,
bg="white",
height=50
)
self.confidence_canvas.pack(fill=tk.X, padx=10, pady=10)
self.confidence_canvas.create_text(
150, 25,
text="识别后显示置信度",
fill="gray",
font=("Arial", 10)
)
# 2.3 候选数字区域
candidates_frame = tk.LabelFrame(right_frame, text="可能的数字", font=("Arial", 10, "bold"))
candidates_frame.pack(fill=tk.X, padx=5, pady=5)
# 候选数字表格
columns = ("数字", "概率")
self.candidates_tree = ttk.Treeview(
candidates_frame,
columns=columns,
show="headings",
height=4
)
# 配置表格列
for col in columns:
self.candidates_tree.heading(col, text=col)
self.candidates_tree.column(col, width=80, anchor=tk.CENTER)
# 添加滚动条
scrollbar = ttk.Scrollbar(
candidates_frame,
orient=tk.VERTICAL,
command=self.candidates_tree.yview
)
self.candidates_tree.configure(yscroll=scrollbar.set)
self.candidates_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
# 3. 模型性能对比和训练集管理区域
# 3.1 模型性能对比区域
performance_frame = tk.LabelFrame(main_frame, text="模型性能对比", font=("Arial", 10, "bold"))
performance_frame.grid(row=2, column=0, padx=5, pady=5, sticky="nsew")
# 性能表格
columns = ("模型名称", "准确率")
self.performance_tree = ttk.Treeview(
performance_frame,
columns=columns,
show="headings",
height=8
)
# 配置表格列
for col in columns:
self.performance_tree.heading(col, text=col)
self.performance_tree.column(col, width=120, anchor=tk.CENTER)
# 添加滚动条
scrollbar = ttk.Scrollbar(
performance_frame,
orient=tk.VERTICAL,
command=self.performance_tree.yview
)
self.performance_tree.configure(yscroll=scrollbar.set)
self.performance_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
# 3.2 训练集管理区域
train_frame = tk.LabelFrame(main_frame, text="训练集管理", font=("Arial", 10, "bold"))
train_frame.grid(row=2, column=1, padx=5, pady=5, sticky="nsew")
# 训练集管理按钮
tk.Button(
train_frame,
text="保存为训练样本",
command=self.save_as_training_sample,
width=18,
height=2
).grid(row=0, column=0, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="保存全部训练集",
command=self.save_all_training_data,
width=18,
height=2
).grid(row=0, column=1, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="加载训练集",
command=self.load_training_data,
width=18,
height=2
).grid(row=1, column=0, padx=5, pady=5, sticky="ew")
tk.Button(
train_frame,
text="性能图表",
command=self.show_performance_chart,
width=18,
height=2
).grid(row=1, column=1, padx=5, pady=5, sticky="ew")
# 4. 状态栏
self.status_var = tk.StringVar(value="就绪")
status_bar = tk.Label(
self.root,
textvariable=self.status_var,
bd=1,
relief=tk.SUNKEN,
anchor=tk.W,
font=("Arial", 10)
)
status_bar.pack(side=tk.BOTTOM, fill=tk.X)
# 配置布局权重
main_frame.grid_columnconfigure(0, weight=1)
main_frame.grid_columnconfigure(1, weight=1)
main_frame.grid_rowconfigure(1, weight=1)
main_frame.grid_rowconfigure(2, weight=1)
# ========================
# 绘图功能
# ========================
def start_draw(self, event):
"""开始绘制"""
self.drawing = True
self.last_x, self.last_y = event.x, event.y
def draw(self, event):
"""绘制"""
if not self.drawing:
return
x, y = event.x, event.y
# 在画布上绘制
self.canvas.create_line(
self.last_x, self.last_y, x, y,
fill="black",
width=self.BRUSH_SIZE,
capstyle=tk.ROUND,
smooth=True
)
# 在图像上绘制(用于后续处理)
self.draw_obj.line(
[self.last_x, self.last_y, x, y],
fill=0, # 黑色
width=self.BRUSH_SIZE
)
# 更新位置
self.last_x, self.last_y = x, y
def stop_draw(self, event):
"""停止绘制"""
self.drawing = False
self.has_drawn = True
self.status_var.set("已绘制数字,点击'识别'进行识别")
def clear_canvas(self):
"""清除画布"""
# 清除画布内容
self.canvas.delete("all")
# 重置图像
self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255) # 白色背景
self.draw_obj = ImageDraw.Draw(self.image)
# 添加绘制提示
self.canvas.create_text(
self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
text="绘制数字", fill="gray", font=("Arial", 16)
)
# 重置结果显示
self.result_label.config(text="请绘制数字")
self.prob_label.config(text="")
self.clear_confidence_display()
self.has_drawn = False
self.status_var.set("画布已清除")
def clear_confidence_display(self):
"""清除置信度显示"""
self.confidence_canvas.delete("all")
self.confidence_canvas.create_text(
150, 25,
text="识别后显示置信度",
fill="gray",
font=("Arial", 10)
)
# 清空候选数字表格
for item in self.candidates_tree.get_children():
self.candidates_tree.delete(item)
# ========================
# 图像处理功能
# ========================
def preprocess_image(self):
"""预处理手写数字图像"""
# 将PIL图像转换为NumPy数组
img_array = np.array(self.image)
# 1. 高斯模糊降噪
img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
# 2. 二值化(转换为黑白图像)
_, img_array = cv2.threshold(img_array, 127, 255, cv2.THRESH_BINARY_INV)
# 3. 轮廓检测(查找数字轮廓)
contours, _ = cv2.findContours(img_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
self.status_var.set("未检测到有效数字,请重新绘制")
return None
# 4. 找到最大轮廓(即数字部分)
c = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(c)
# 5. 提取数字区域
digit = img_array[y:y+h, x:x+w]
# 6. 填充为正方形(保持长宽比)
size = max(w, h)
padded = np.ones((size, size), dtype=np.uint8) * 255 # 白色背景
offset_x = (size - w) // 2
offset_y = (size - h) // 2
padded[offset_y:offset_y+h, offset_x:offset_x+w] = digit
# 7. 缩放为8x8(匹配MNIST数据集格式)
resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA)
# 8. 归一化(将像素值从0-255映射到0-16)
normalized = 16 - (resized / 255 * 16).astype(np.uint8)
# 9. 展平为一维数组(64个特征)
return normalized.flatten()
# ========================
# 识别功能
# ========================
def recognize(self):
"""识别手写数字"""
# 检查是否已绘制数字
if not self.has_drawn:
self.status_var.set("请先绘制数字再识别")
return
# 检查模型是否已加载
if self.current_model is None:
self.status_var.set("模型未加载,请选择模型")
return
# 预处理图像
img_array = self.preprocess_image()
if img_array is None:
return
# 重塑为模型输入格式(1个样本,64个特征)
img_input = img_array.reshape(1, -1)
try:
# 数据标准化
if self.scaler:
img_input = self.scaler.transform(img_input)
# LightGBM特殊处理(需要DataFrame格式)
if self.current_model_type == 'lgb' and hasattr(self.current_model, 'feature_name_'):
img_input = pd.DataFrame(img_input, columns=self.current_model.feature_name_)
# 预测数字
pred = self.current_model.predict(img_input)[0]
self.result_label.config(text=f"识别结果: {pred}")
# 概率预测(如果模型支持)
if hasattr(self.current_model, 'predict_proba'):
probs = self.current_model.predict_proba(img_input)[0]
confidence = probs[pred] # 预测结果的置信度
# 更新UI
self.prob_label.config(text=f"置信度: {confidence:.2%}")
self.update_confidence_display(confidence) # 更新置信度可视化
# 显示候选数字(概率最高的3个)
top3 = sorted(enumerate(probs), key=lambda x: -x[1])[:3]
self.update_candidates_display(top3)
else:
self.prob_label.config(text="该模型不支持概率输出")
self.clear_confidence_display()
self.status_var.set(f"识别完成: 数字 {pred}")
except Exception as e:
self.status_var.set(f"识别错误: {str(e)}")
self.clear_confidence_display()
# ========================
# UI更新功能
# ========================
def update_confidence_display(self, confidence):
"""更新置信度可视化"""
self.confidence_canvas.delete("all")
# 获取画布宽度
canvas_width = self.confidence_canvas.winfo_width() or 300
# 绘制背景
self.confidence_canvas.create_rectangle(
10, 10, canvas_width - 10, 40,
fill="#f0f0f0",
outline="#cccccc"
)
# 绘制置信度条(根据置信度值)
bar_width = int((canvas_width - 20) * confidence)
color = self.get_confidence_color(confidence) # 根据置信度选择颜色
self.confidence_canvas.create_rectangle(
10, 10, 10 + bar_width, 40,
fill=color,
outline=""
)
# 绘制文本(显示百分比)
self.confidence_canvas.create_text(
canvas_width / 2, 25,
text=f"{confidence:.1%}",
font=("Arial", 10, "bold")
)
# 绘制刻度
for i in range(0, 11):
x_pos = 10 + i * (canvas_width - 20) / 10
self.confidence_canvas.create_line(x_pos, 40, x_pos, 45, width=1)
if i % 2 == 0:
self.confidence_canvas.create_text(x_pos, 55, text=f"{i*10}%", font=("Arial", 8))
def get_confidence_color(self, confidence):
"""根据置信度获取颜色"""
# 高置信度:绿色
if confidence >= 0.9:
return "#4CAF50"
# 中等置信度:黄色
elif confidence >= 0.7:
return "#FFC107"
# 低置信度:红色
else:
return "#F44336"
def update_candidates_display(self, candidates):
"""更新候选数字显示"""
# 清空现有项
for item in self.candidates_tree.get_children():
self.candidates_tree.delete(item)
# 添加新项(候选数字及其概率)
for digit, prob in candidates:
self.candidates_tree.insert(
"", tk.END,
values=(digit, f"{prob:.2%}")
)
# ========================
# 样本显示功能
# ========================
def show_samples(self):
"""显示样本图像"""
plt.figure(figsize=(10, 4))
# 显示0-9每个数字的一个样本
for i in range(10):
plt.subplot(2, 5, i+1)
sample_idx = np.where(self.digits.target == i)[0][0]
plt.imshow(self.digits.images[sample_idx], cmap="gray")
plt.title(f"数字 {i}", fontsize=9)
plt.axis("off")
plt.tight_layout()
plt.show()
# ========================
# 模型管理功能
# ========================
def on_model_select(self, event):
"""模型选择事件处理"""
# 获取选中的模型名称
selected_name = self.model_var.get()
# 查找对应的模型类型
model_type = next(
(k for k, v in self.available_models if v == selected_name),
None
)
if model_type:
# 切换模型
self.change_model(model_type)
def change_model(self, model_type):
"""切换模型"""
model_name = MODEL_METADATA[model_type][0]
# 尝试从缓存加载模型
if model_type in self.model_cache:
self.current_model, self.scaler, accuracy, self.current_model_type = self.model_cache[model_type]
self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
self.status_var.set(f"已加载模型: {model_name}")
return
# 加载新模型
self.status_var.set(f"正在加载模型: {model_name}...")
self.root.update() # 更新UI显示状态
try:
# 获取数据集
X_train, X_test, y_train, y_test = self.model_factory.get_split_data(self.digits)
# 训练并评估模型
self.current_model, self.scaler, accuracy = self.model_factory.train_and_evaluate(
model_type, X_train, y_train, X_test, y_test
)
# 缓存模型
self.current_model_type = model_type
self.model_cache[model_type] = (self.current_model, self.scaler, accuracy, self.current_model_type)
# 更新UI
self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
self.status_var.set(f"模型加载完成: {model_name}, 准确率: {accuracy:.4f}")
self.clear_canvas()
# 更新性能表格
self.load_performance_data()
except Exception as e:
self.status_var.set(f"模型加载失败: {str(e)}")
self.model_label.config(text="模型加载失败")
def init_default_model(self):
"""初始化默认模型"""
# 设置默认模型并加载
self.model_var.set(self.available_models[0][1])
self.change_model(self.available_models[0][0])
def load_performance_data(self):
"""加载性能数据"""
# 评估所有模型
results = self.model_factory.evaluate_all_models(self.digits)
# 清空表格
for item in self.performance_tree.get_children():
self.performance_tree.delete(item)
# 添加数据到表格
for i, result in enumerate(results):
tag = "highlight" if i == 0 else "" # 高亮显示性能最好的模型
self.performance_tree.insert(
"", tk.END,
values=(result["模型名称"], result["准确率"]),
tags=(tag,)
)
# 配置高亮样式
self.performance_tree.tag_configure("highlight", background="#e6f7ff")
# ========================
# 性能可视化功能
# ========================
def show_performance_chart(self):
"""显示性能图表"""
# 获取性能数据
results = self.model_factory.evaluate_all_models(self.digits)
# 提取有效结果(过滤掉错误数据)
valid_results = []
for result in results:
try:
accuracy = float(result["准确率"])
valid_results.append((result["模型名称"], accuracy))
except ValueError:
continue
if not valid_results:
messagebox.showinfo("提示", "没有可用的性能数据")
return
# 按准确率排序
valid_results.sort(key=lambda x: x[1], reverse=True)
models, accuracies = zip(*valid_results)
# 创建水平条形图
plt.figure(figsize=(10, 5))
bars = plt.barh(models, accuracies, color='#2196F3')
plt.xlabel('准确率', fontsize=10)
plt.ylabel('模型', fontsize=10)
plt.title('模型性能对比', fontsize=12)
plt.xlim(0, 1.05) # 设置X轴范围
# 添加数值标签
for bar in bars:
width = bar.get_width()
plt.text(
width + 0.01,
bar.get_y() + bar.get_height()/2,
f'{width:.4f}',
ha='left',
va='center',
fontsize=8
)
plt.tight_layout()
plt.show()
# ========================
# 训练集管理功能
# ========================
def save_as_training_sample(self):
"""保存为训练样本"""
# 检查是否已绘制数字
if not self.has_drawn:
self.status_var.set("请先绘制数字再保存")
return
# 预处理图像
img_array = self.preprocess_image()
if img_array is None:
return
# 弹出标签输入窗口
label_window = tk.Toplevel(self.root)
label_window.title("输入标签")
label_window.geometry("300x150")
label_window.transient(self.root)
label_window.grab_set() # 模态窗口
# 标签输入提示
tk.Label(
label_window,
text="请输入数字标签 (0-9):",
font=("Arial", 10)
).pack(pady=10)
# 输入框
entry = tk.Entry(label_window, font=("Arial", 12), width=5)
entry.pack(pady=5)
entry.focus_set()
def save_with_label():
"""保存带标签的样本"""
try:
# 验证标签
label = int(entry.get())
if label < 0 or label > 9:
raise ValueError("标签必须是0-9的数字")
# 添加到自定义数据集
self.custom_data.append((img_array.tolist(), label))
self.status_var.set(f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)")
label_window.destroy()
except ValueError as e:
self.status_var.set(f"保存错误: {str(e)}")
# 保存按钮
tk.Button(
label_window,
text="保存",
command=save_with_label,
width=10
).pack(pady=5)
def save_all_training_data(self):
"""保存全部训练数据"""
# 检查是否有数据可保存
if not self.custom_data:
self.status_var.set("没有训练数据可保存")
return
# 弹出文件保存对话框
file_path = filedialog.asksaveasfilename(
defaultextension=".csv",
filetypes=[("CSV文件", "*.csv")],
initialfile="custom_digits.csv",
title="保存训练集"
)
if not file_path:
return
try:
# 写入CSV文件
with open(file_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
# 写入表头(64个像素+标签)
writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
# 写入数据
for img_data, label in self.custom_data:
writer.writerow(img_data + [label])
self.status_var.set(f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}")
except Exception as e:
self.status_var.set(f"保存失败: {str(e)}")
def load_training_data(self):
"""加载训练数据"""
# 弹出文件选择对话框
file_path = filedialog.askopenfilename(
filetypes=[("CSV文件", "*.csv")],
title="加载训练集"
)
if not file_path:
return
try:
self.custom_data = []
# 读取CSV文件
with open(file_path, 'r', newline='', encoding='utf-8') as f:
reader = csv.reader(f)
next(reader) # 跳过标题行
# 解析每一行数据
for row in reader:
if len(row) != 65: # 64像素+1标签
continue
# 提取像素数据和标签
img_data = [float(pixel) for pixel in row[:64]]
label = int(row[64])
self.custom_data.append((img_data, label))
self.status_var.set(f"已加载 {len(self.custom_data)} 个样本")
except Exception as e:
self.status_var.set(f"加载失败: {str(e)}")
# ========================
# 主程序入口
# ========================
def run(self):
"""运行应用"""
self.root.mainloop()
# ========================
# 程序入口
# ========================
if __name__ == "__main__":
digits = load_digits() # 加载数字数据集
root = tk.Tk() # 创建主窗口
app = HandwritingBoard(root, ModelFactory, digits) # 创建应用实例
app.run() # 运行应用
请你根据上面的内容生成符合要求的课程设计:
最新发布