input 输入框自动填充缓存值后出现黄色背景的问题及清除方法

项目中遇到一个问题:输入框一旦被浏览器自动填充(使用缓存值),背景就会变成黄色,而手动输入则显示正常。只需添加下面这条 CSS 规则即可解决,拿走不谢。

/*
 * Chromium/WebKit paint autofilled inputs with their own background colour,
 * and their user-agent :-webkit-autofill rules use !important, so a plain
 * background-color override does not take effect. An inset box-shadow larger
 * than the field is painted on top of the background and hides it.
 */
input:-webkit-autofill {
	-webkit-box-shadow: 0 0 0 1000px white inset !important;
	box-shadow: 0 0 0 1000px white inset !important;
}

/*
 * Standardised pseudo-class, for browsers that implement :autofill.
 * Deliberately a separate rule: if it were merged into the selector list
 * above, a browser that does not recognise one of the two pseudo-classes
 * would discard the whole rule.
 */
input:autofill {
	box-shadow: 0 0 0 1000px white inset !important;
}
课程设计与要求:
实验18 手写数字识别程序设计与实现
实验类型:设计性实验  实验学时:8
涉及的知识点:SVM、决策树、随机森林、XGBoost和LightGBM机器学习算法的综合应用
一、实验目的
1、了解机器学习算法应用项目的设计流程与基本方法。
2、掌握SVM应用设计,以及用K折交叉验证法获得测试数据。
3、熟悉两种以上不同类型的机器学习算法及其应用。
4、掌握各类机器学习算法的区别与优缺点;会应用网格搜索选择最优超参数。
5、掌握分类任务的性能指标评价方法。
二、实验要求
1、使用Anaconda集成开发环境完成课程设计,代码的可维护性好,有必要的注释和相应的文档。
2、能够识别符合分辨率要求的手写数字。
3、构建不同模型实现手写数字分类识别,至少对比两种方法,如决策树、支持向量机、随机森林、XGBoost和LightGBM等;对比不同模型的分类性能报告,评价模型好坏。
4、数据集采用sklearn.datasets中的digits,测试集数据可以自己手写产生,或者从digits中拆分。
三、设计指标
1、完整的设计文档:1) 系统的需求分析;2) 系统的概要设计;3) 详细设计与实现;4) 系统测试方法。
2、运行画面截图。
3、每一部分附上关键性代码。
4、项目总结。
四、预习与参考
1、教材中有关决策树、SVM(支持向量机)、随机森林、XGBoost和LightGBM的章节。
2、课程PPT有关内容。
3、中国知网中有关手写数字识别的文献资料。
五、考核形式
根据提交的设计文档完成程度以及程序功能的实现情况(要求演示)进行考核:
• 无任何文档,无程序,得0分;
• 文档描述不清楚,思路混乱,程序不能运行,得2分;
• 文档描述清晰,程序实现了基本功能,得3.5分;
• 文档描述清晰准确,思路清晰,程序实现了要求的所有功能,得4.5分;
• 文档完备,设计合理有创新,报告清晰明确,深入分析了自己进行实验的体会感想,程序实现了全部功能且功能完善,并有其他创新实现,得5分。
六、实验报告要求
1、实验目的结合自己个人的实际情况书写,不要雷同。
2、项目概要设计说明书(描述软件系统架构、逻辑架构、物理架构、部署结构、功能架构及关键技术,关键业务模块需通过UML图进行详细描述)、需求规格说明书(包括功能设计、非功能性设计、系统用例)。
3、项目设计运行截图。
代码程序: # ======================== # 导入必要的库 # ======================== import numpy as np # 数计算库 import matplotlib.pyplot as plt # 绘图库 import pandas as pd # 数据处理库 import tkinter as tk # GUI库 from tkinter import ttk, filedialog, messagebox # GUI组件 from PIL import Image, ImageDraw # 图像处理库 import cv2 # 计算机视觉库 import os # 操作系统接口 import csv # CSV文件处理 from sklearn.datasets import load_digits # 加载数字数据集 from sklearn.model_selection import train_test_split # 数据集划分 from sklearn.svm import SVC # 支持向量机模型 from sklearn.tree import DecisionTreeClassifier # 决策树模型 from sklearn.ensemble import RandomForestClassifier # 随机森林模型 from sklearn.neural_network import MLPClassifier # 多层感知机模型 from sklearn.neighbors import KNeighborsClassifier # K近邻模型 from sklearn.naive_bayes import GaussianNB # 朴素贝叶斯模型 from sklearn.metrics import accuracy_score # 准确率评估 from sklearn.preprocessing import StandardScaler # 数据标准化 # 设置中文字体和负号显示(解决中文乱码问题) plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"] 
plt.rcParams["axes.unicode_minus"] = False  # keep the minus sign readable with CJK fonts

# ========================
# Optional model libraries
# ========================
XGB_INSTALLED = False  # set to True once xgboost imports successfully
LGB_INSTALLED = False  # set to True once lightgbm imports successfully
try:
    import xgboost as xgb
    XGB_INSTALLED = True
except ImportError:
    print("警告: 未安装XGBoost库,无法使用XGBoost模型")
try:
    import lightgbm as lgb
    LGB_INSTALLED = True
except ImportError:
    print("警告: 未安装LightGBM库,无法使用LightGBM模型")

# ========================
# Model configuration
# ========================
# model_type -> (display name, estimator class, scaler class or None, estimator kwargs).
# SVM, MLP and KNN get a StandardScaler; the other models are trained on the
# raw feature values.
MODEL_METADATA = {
    'svm': ('支持向量机(SVM)', SVC, StandardScaler,
            {'probability': True, 'random_state': 42}),
    'dt': ('决策树(DT)', DecisionTreeClassifier, None, {'random_state': 42}),
    'rf': ('随机森林(RF)', RandomForestClassifier, None,
           {'n_estimators': 100, 'random_state': 42}),
    'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler,
            {'hidden_layer_sizes': (100, 50), 'max_iter': 500, 'random_state': 42}),
    'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler,
            {'n_neighbors': 5, 'weights': 'distance'}),
    'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None, {}),
}

# Optional models are only registered when their library imported.
if XGB_INSTALLED:
    # 'multi:softprob' rather than 'multi:softmax': the GUI always calls
    # predict_proba, and some XGBoost releases refuse predict_proba for the
    # softmax objective.
    MODEL_METADATA['xgb'] = ('XGBoost(XGB)', xgb.XGBClassifier, None,
                             {'objective': 'multi:softprob', 'random_state': 42})
if LGB_INSTALLED:
    MODEL_METADATA['lgb'] = ('LightGBM(LGB)', lgb.LGBMClassifier, None, {
        'objective': 'multiclass',
        'random_state': 42,
        'num_class': 10,
        'max_depth': 5,
        'min_child_samples': 10,
        'learning_rate': 0.1,
        'force_col_wise': True,
    })


# ========================
# Model factory: creation, training and evaluation
# ========================
class ModelFactory:
    """Stateless helpers that build, train and score the models in MODEL_METADATA."""

    @staticmethod
    def get_split_data(digits_dataset):
        """Split a dataset exposing ``.data``/``.target`` into 70% train / 30% test.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` from ``train_test_split``;
            ``random_state=42`` makes the split reproducible.
        """
        X, y = digits_dataset.data, digits_dataset.target
        return train_test_split(X, y, test_size=0.3, random_state=42)

    @classmethod
    def create_model(cls, model_type):
        """Instantiate the estimator and (optional) scaler configured for ``model_type``.

        Returns:
            ``(model, scaler)``; ``scaler`` is None when the entry has no scaler class.

        Raises:
            ValueError: ``model_type`` is not a key of MODEL_METADATA.
            ImportError: the entry has no usable estimator class.
        """
        if model_type not in MODEL_METADATA:
            raise ValueError(f"未知模型类型: {model_type}")
        name, model_cls, scaler_cls, params = MODEL_METADATA[model_type]
        if not model_cls:
            raise ImportError(f"{name}模型依赖库未安装")
        model = model_cls(**params)
        scaler = scaler_cls() if scaler_cls else None
        return model, scaler

    @staticmethod
    def train_model(model, X_train, y_train, scaler=None, model_type=None):
        """Fit ``scaler`` (if given) and then ``model`` on the training data.

        Returns:
            The fitted ``model``.
        """
        if scaler is not None:
            X_train = scaler.fit_transform(X_train)
        # LightGBM is trained on a DataFrame; evaluate_model and the GUI rebuild
        # prediction inputs with the same column names (model.feature_name_).
        if model_type == 'lgb' and isinstance(X_train, np.ndarray):
            X_train = pd.DataFrame(X_train)
        model.fit(X_train, y_train)
        return model

    @staticmethod
    def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
        """Return the accuracy of ``model`` on the test data.

        ``scaler`` must be the one fitted by train_model (only ``transform`` is called).
        """
        if scaler is not None:
            X_test = scaler.transform(X_test)
        if model_type == 'lgb' and isinstance(X_test, np.ndarray) and hasattr(model, 'feature_name_'):
            X_test = pd.DataFrame(X_test, columns=model.feature_name_)
        y_pred = model.predict(X_test)
        return accuracy_score(y_test, y_pred)

    @classmethod
    def train_and_evaluate(cls, model_type, X_train, y_train, X_test, y_test):
        """Create, train and score one model.

        Returns:
            ``(model, scaler, accuracy)``.

        Raises:
            Whatever creation/training/evaluation raised, after printing it.
        """
        try:
            model, scaler = cls.create_model(model_type)
            model = cls.train_model(model, X_train, y_train, scaler, model_type)
            accuracy = cls.evaluate_model(model, X_test, y_test, scaler, model_type)
            return model, scaler, accuracy
        except Exception as e:
            print(f"模型 {model_type} 训练/评估错误: {str(e)}")
            raise

    @staticmethod
    def _accuracy_sort_key(result):
        """Numeric accuracy of a result row; non-numeric rows ("N/A", errors) sort last."""
        try:
            return float(result["准确率"])
        except (TypeError, ValueError):
            return -1.0

    @classmethod
    def evaluate_all_models(cls, digits_dataset):
        """Train and score every model in MODEL_METADATA on the same split.

        Returns:
            A list of ``{"模型名称": name, "准确率": text}`` dicts, best first.
            ``准确率`` is a 4-decimal string, "N/A", or an error message; the
            non-numeric entries are placed last.
        """
        print("\n=== 模型评估 ===")
        X_train, X_test, y_train, y_test = cls.get_split_data(digits_dataset)
        results = []
        for model_type, (name, model_cls, _, _) in MODEL_METADATA.items():
            print(f"评估模型: {name} ({model_type})")
            if not model_cls:
                results.append({"模型名称": name, "准确率": "N/A"})
                continue
            try:
                _, _, accuracy = cls.train_and_evaluate(
                    model_type, X_train, y_train, X_test, y_test
                )
            except Exception as e:
                results.append({"模型名称": name, "准确率": f"错误: {str(e)}"})
            else:
                results.append({"模型名称": name, "准确率": f"{accuracy:.4f}"})
        results.sort(key=cls._accuracy_sort_key, reverse=True)
        print(pd.DataFrame(results))
        return results


# ========================
# Handwriting board: Tk GUI and drawing
# ========================
class HandwritingBoard:
    """Tk window where the user draws a digit and classifies it with a chosen model."""

    CANVAS_SIZE = 300  # side length of the square drawing canvas, in pixels
    BRUSH_SIZE = 12    # stroke width, in pixels

    def __init__(self, root, model_factory, digits):
        """Build the window, then train/load the first available model.

        Args:
            root: the Tk root window.
            model_factory: object exposing the ModelFactory API (the class itself is passed).
            digits: dataset with ``.data``, ``.target`` and ``.images`` (sklearn digits).
        """
        self.root = root
        self.root.title("手写数字识别系统")
        self.root.geometry("1000x700")

        # Model / data state
        self.model_factory = model_factory
        self.digits = digits
        self.model_cache = {}            # model_type -> (model, scaler, accuracy, model_type)
        self.current_model = None
        self.scaler = None
        self.current_model_type = None
        self.has_drawn = False
        self.custom_data = []            # list of (64 feature values, label) tuples
        self.performance_results = None  # cached evaluate_all_models() output

        # Drawing state
        self.drawing = False
        self.last_x = self.last_y = 0

        # Default folder for the training-set file dialogs.
        self.data_dir = "custom_digits_data"
        os.makedirs(self.data_dir, exist_ok=True)

        # Off-screen greyscale copy of the strokes (white paper) used for recognition.
        self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw_obj = ImageDraw.Draw(self.image)

        self.create_widgets()
        self.init_default_model()

    # ========================
    # Widget construction
    # ========================
    @staticmethod
    def _create_table(parent, columns, height, col_width):
        """Create a headings-only Treeview with a vertical scrollbar packed inside ``parent``."""
        tree = ttk.Treeview(parent, columns=columns, show="headings", height=height)
        for col in columns:
            tree.heading(col, text=col)
            tree.column(col, width=col_width, anchor=tk.CENTER)
        scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=tree.yview)
        tree.configure(yscrollcommand=scrollbar.set)
        tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
        return tree

    def _draw_canvas_hint(self):
        """Show the grey placeholder text in the middle of the empty drawing canvas."""
        self.canvas.create_text(
            self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
            text="绘制数字", fill="gray", font=("Arial", 16)
        )

    def create_widgets(self):
        """Build the window: model selector, drawing area, result panes,
        performance table, training-set buttons and status bar."""
        main_frame = tk.Frame(self.root)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # --- 1. Model selection row ---
        model_frame = tk.LabelFrame(main_frame, text="模型选择", font=("Arial", 10, "bold"))
        model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", padx=5, pady=5)
        model_frame.grid_columnconfigure(1, weight=1)  # let the combobox column stretch

        tk.Label(model_frame, text="选择模型:", font=("Arial", 10)).grid(
            row=0, column=0, padx=5, pady=5, sticky="w"
        )

        # (model_type, display name) for every entry that has an estimator class.
        self.available_models = [
            (model_type, name)
            for model_type, (name, model_cls, _, _) in MODEL_METADATA.items()
            if model_cls
        ]

        self.model_var = tk.StringVar()
        self.model_combobox = ttk.Combobox(
            model_frame,
            textvariable=self.model_var,
            values=[name for _, name in self.available_models],
            state="readonly",
            width=25,
            font=("Arial", 10),
        )
        self.model_combobox.current(0)
        self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select)
        self.model_combobox.grid(row=0, column=1, padx=5, pady=5, sticky="ew")

        # Shows the active model and its test accuracy.
        self.model_label = tk.Label(
            model_frame, text="", font=("Arial", 10), relief=tk.SUNKEN, padx=5, pady=2
        )
        self.model_label.grid(row=0, column=2, padx=5, pady=5, sticky="ew")

        # --- 2a. Drawing area (left) ---
        left_frame = tk.LabelFrame(main_frame, text="绘制区域", font=("Arial", 10, "bold"))
        left_frame.grid(row=1, column=0, padx=5, pady=5, sticky="nsew")

        self.canvas = tk.Canvas(
            left_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE
        )
        self.canvas.pack(padx=10, pady=10)
        self.canvas.bind("<Button-1>", self.start_draw)
        self.canvas.bind("<B1-Motion>", self.draw)
        self.canvas.bind("<ButtonRelease-1>", self.stop_draw)
        self._draw_canvas_hint()

        btn_frame = tk.Frame(left_frame)
        btn_frame.pack(fill=tk.X, pady=(0, 10))
        for text, command in (
            ("识别", self.recognize),
            ("清除", self.clear_canvas),
            ("样本", self.show_samples),
        ):
            tk.Button(btn_frame, text=text, command=command, width=8).pack(side=tk.LEFT, padx=5)

        # --- 2b. Result area (right) ---
        right_frame = tk.Frame(main_frame)
        right_frame.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")

        result_frame = tk.LabelFrame(right_frame, text="识别结果", font=("Arial", 10, "bold"))
        result_frame.pack(fill=tk.X, padx=5, pady=5)
        self.result_label = tk.Label(result_frame, text="请绘制数字", font=("Arial", 24), pady=10)
        self.result_label.pack()
        self.prob_label = tk.Label(result_frame, text="", font=("Arial", 12))
        self.prob_label.pack()

        confidence_frame = tk.LabelFrame(right_frame, text="识别置信度", font=("Arial", 10, "bold"))
        confidence_frame.pack(fill=tk.X, padx=5, pady=5)
        # Tall enough for the bar (y 10-40), the ticks (y 40-45) and the
        # percentage labels that update_confidence_display draws at y=55.
        self.confidence_canvas = tk.Canvas(confidence_frame, bg="white", height=65)
        self.confidence_canvas.pack(fill=tk.X, padx=10, pady=10)
        self.confidence_canvas.create_text(
            150, 25, text="识别后显示置信度", fill="gray", font=("Arial", 10)
        )

        candidates_frame = tk.LabelFrame(right_frame, text="可能的数字", font=("Arial", 10, "bold"))
        candidates_frame.pack(fill=tk.X, padx=5, pady=5)
        self.candidates_tree = self._create_table(
            candidates_frame, ("数字", "概率"), height=4, col_width=80
        )

        # --- 3a. Model performance comparison ---
        performance_frame = tk.LabelFrame(main_frame, text="模型性能对比", font=("Arial", 10, "bold"))
        performance_frame.grid(row=2, column=0, padx=5, pady=5, sticky="nsew")
        self.performance_tree = self._create_table(
            performance_frame, ("模型名称", "准确率"), height=8, col_width=120
        )
        self.performance_tree.tag_configure("highlight", background="#e6f7ff")

        # --- 3b. Training-set management (2 x 2 button grid) ---
        train_frame = tk.LabelFrame(main_frame, text="训练集管理", font=("Arial", 10, "bold"))
        train_frame.grid(row=2, column=1, padx=5, pady=5, sticky="nsew")
        train_buttons = (
            ("保存为训练样本", self.save_as_training_sample, 0, 0),
            ("保存全部训练集", self.save_all_training_data, 0, 1),
            ("加载训练集", self.load_training_data, 1, 0),
            ("性能图表", self.show_performance_chart, 1, 1),
        )
        for text, command, row, column in train_buttons:
            tk.Button(train_frame, text=text, command=command, width=18, height=2).grid(
                row=row, column=column, padx=5, pady=5, sticky="ew"
            )

        # --- 4. Status bar ---
        self.status_var = tk.StringVar(value="就绪")
        status_bar = tk.Label(
            self.root, textvariable=self.status_var, bd=1,
            relief=tk.SUNKEN, anchor=tk.W, font=("Arial", 10)
        )
        status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        # Both columns and the two content rows share the extra space.
        for index in (0, 1):
            main_frame.grid_columnconfigure(index, weight=1)
        for index in (1, 2):
            main_frame.grid_rowconfigure(index, weight=1)

    # ========================
    # Drawing
    # ========================
    def start_draw(self, event):
        """Mouse-down: remember the stroke's start point."""
        self.drawing = True
        self.last_x, self.last_y = event.x, event.y

    def draw(self, event):
        """Mouse-drag: extend the stroke on the visible canvas and on the PIL image."""
        if not self.drawing:
            return
        x, y = event.x, event.y
        self.canvas.create_line(
            self.last_x, self.last_y, x, y,
            fill="black", width=self.BRUSH_SIZE, capstyle=tk.ROUND, smooth=True
        )
        self.draw_obj.line([self.last_x, self.last_y, x, y], fill=0, width=self.BRUSH_SIZE)
        # PIL draws thick lines with flat ends, so consecutive short segments
        # leave notches at their joints. Stamping a disc at both ends mimics
        # the canvas's round caps and keeps the off-screen stroke solid.
        radius = self.BRUSH_SIZE / 2
        for px, py in ((self.last_x, self.last_y), (x, y)):
            self.draw_obj.ellipse([px - radius, py - radius, px + radius, py + radius], fill=0)
        self.last_x, self.last_y = x, y

    def stop_draw(self, event):
        """Mouse-up: finish the stroke and enable recognition."""
        self.drawing = False
        self.has_drawn = True
        self.status_var.set("已绘制数字,点击'识别'进行识别")

    def clear_canvas(self):
        """Erase the canvas and the PIL image and reset every result display."""
        self.canvas.delete("all")
        self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw_obj = ImageDraw.Draw(self.image)
        self._draw_canvas_hint()

        self.result_label.config(text="请绘制数字")
        self.prob_label.config(text="")
        self.clear_confidence_display()
        self.has_drawn = False
        self.status_var.set("画布已清除")

    def clear_confidence_display(self):
        """Reset the confidence bar to its placeholder and empty the candidate table."""
        self.confidence_canvas.delete("all")
        self.confidence_canvas.create_text(
            150, 25, text="识别后显示置信度", fill="gray", font=("Arial", 10)
        )
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)

    # ========================
    # Image preprocessing
    # ========================
    def preprocess_image(self):
        """Turn the off-screen drawing into a 64-value feature vector.

        The output follows the sklearn ``digits`` convention: an 8x8 image in
        row-major order, with background 0 and ink up to 16.

        Returns:
            A float ndarray of shape (64,), or None if no stroke was found
            (the status bar is updated in that case).
        """
        img_array = np.array(self.image)

        # 1. Light blur to smooth jagged stroke edges before thresholding.
        img_array = cv2.GaussianBlur(img_array, (5, 5), 0)

        # 2. Inverse threshold: ink -> 255, paper -> 0.
        _, img_array = cv2.threshold(img_array, 127, 255, cv2.THRESH_BINARY_INV)

        # 3. Outer contours of the (now white) ink.
        contours, _ = cv2.findContours(img_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            self.status_var.set("未检测到有效数字,请重新绘制")
            return None

        # 4. Bounding box of all strokes: digits such as 4, 5 or 7 are often
        #    drawn with separate strokes, which form separate contours.
        x, y, w, h = cv2.boundingRect(np.vstack(contours))
        digit = img_array[y:y + h, x:x + w]

        # 5. Centre the crop on a square so resizing keeps its aspect ratio.
        #    The padding must be background (0) in this inverted image;
        #    padding with 255 would paint ink around the digit.
        size = max(w, h)
        padded = np.zeros((size, size), dtype=np.uint8)
        offset_x = (size - w) // 2
        offset_y = (size - h) // 2
        padded[offset_y:offset_y + h, offset_x:offset_x + w] = digit

        # 6. Down-sample to the 8x8 resolution of sklearn's digits dataset.
        resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA)

        # 7. Map 0-255 onto 0-16 without inverting again: the image is already
        #    ink-high, matching the training data.
        normalized = np.rint(resized.astype(np.float64) * 16.0 / 255.0)

        # 8. Flatten to the 64 features the models were trained on.
        return normalized.flatten()

    # ========================
    # Recognition
    # ========================
    def recognize(self):
        """Classify the current drawing with the active model and update the result panes."""
        if not self.has_drawn:
            self.status_var.set("请先绘制数字再识别")
            return
        if self.current_model is None:
            self.status_var.set("模型未加载,请选择模型")
            return

        features = self.preprocess_image()
        if features is None:
            return
        img_input = features.reshape(1, -1)  # one sample, 64 features

        try:
            if self.scaler is not None:
                img_input = self.scaler.transform(img_input)
            if self.current_model_type == 'lgb' and hasattr(self.current_model, 'feature_name_'):
                img_input = pd.DataFrame(img_input, columns=self.current_model.feature_name_)

            pred = self.current_model.predict(img_input)[0]
            self.result_label.config(text=f"识别结果: {pred}")

            if hasattr(self.current_model, 'predict_proba'):
                probs = self.current_model.predict_proba(img_input)[0]
                # predict_proba columns follow model.classes_; look labels up
                # there instead of assuming column index == digit value.
                classes = list(getattr(self.current_model, 'classes_', range(len(probs))))
                confidence = probs[classes.index(pred)]

                self.prob_label.config(text=f"置信度: {confidence:.2%}")
                self.update_confidence_display(confidence)

                # Three most probable labels, highest first.
                top3 = sorted(zip(classes, probs), key=lambda item: item[1], reverse=True)[:3]
                self.update_candidates_display(top3)
            else:
                self.prob_label.config(text="该模型不支持概率输出")
                self.clear_confidence_display()

            self.status_var.set(f"识别完成: 数字 {pred}")
        except Exception as e:
            self.status_var.set(f"识别错误: {str(e)}")
            self.clear_confidence_display()

    # ========================
    # Result display
    # ========================
    def update_confidence_display(self, confidence):
        """Draw a horizontal bar for ``confidence`` (0-1) with a 0-100% scale."""
        self.confidence_canvas.delete("all")

        canvas_width = self.confidence_canvas.winfo_width()
        if canvas_width <= 20:
            # Tk reports 1 before the widget is laid out; fall back to a sane width.
            canvas_width = 300

        # Background track
        self.confidence_canvas.create_rectangle(
            10, 10, canvas_width - 10, 40, fill="#f0f0f0", outline="#cccccc"
        )
        # Filled portion, coloured by confidence level
        bar_width = int((canvas_width - 20) * confidence)
        self.confidence_canvas.create_rectangle(
            10, 10, 10 + bar_width, 40,
            fill=self.get_confidence_color(confidence), outline=""
        )
        # Percentage text
        self.confidence_canvas.create_text(
            canvas_width / 2, 25, text=f"{confidence:.1%}", font=("Arial", 10, "bold")
        )
        # Ticks every 10%, labels every 20%
        for i in range(0, 11):
            x_pos = 10 + i * (canvas_width - 20) / 10
            self.confidence_canvas.create_line(x_pos, 40, x_pos, 45, width=1)
            if i % 2 == 0:
                self.confidence_canvas.create_text(x_pos, 55, text=f"{i*10}%", font=("Arial", 8))

    def get_confidence_color(self, confidence):
        """Return green for >= 0.9, amber for >= 0.7, red otherwise."""
        if confidence >= 0.9:
            return "#4CAF50"
        if confidence >= 0.7:
            return "#FFC107"
        return "#F44336"

    def update_candidates_display(self, candidates):
        """Replace the candidate table with ``(digit, probability)`` pairs."""
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)
        for digit, prob in candidates:
            self.candidates_tree.insert("", tk.END, values=(digit, f"{prob:.2%}"))

    # ========================
    # Dataset samples
    # ========================
    def show_samples(self):
        """Plot the first dataset image of each digit 0-9 in a 2x5 grid."""
        plt.figure(figsize=(10, 4))
        for i in range(10):
            plt.subplot(2, 5, i + 1)
            sample_idx = np.where(self.digits.target == i)[0][0]
            plt.imshow(self.digits.images[sample_idx], cmap="gray")
            plt.title(f"数字 {i}", fontsize=9)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    # ========================
    # Model management
    # ========================
    def on_model_select(self, event):
        """Combobox callback: switch to the model whose display name was chosen."""
        selected_name = self.model_var.get()
        model_type = next(
            (key for key, name in self.available_models if name == selected_name), None
        )
        if model_type:
            self.change_model(model_type)

    def change_model(self, model_type):
        """Activate ``model_type``, training it on first use and caching the result."""
        model_name = MODEL_METADATA[model_type][0]

        cached = self.model_cache.get(model_type)
        if cached is not None:
            self.current_model, self.scaler, accuracy, self.current_model_type = cached
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.status_var.set(f"已加载模型: {model_name}")
            return

        self.status_var.set(f"正在加载模型: {model_name}...")
        self.root.update()  # show the status text before the (blocking) training starts

        try:
            X_train, X_test, y_train, y_test = self.model_factory.get_split_data(self.digits)
            self.current_model, self.scaler, accuracy = self.model_factory.train_and_evaluate(
                model_type, X_train, y_train, X_test, y_test
            )
            self.current_model_type = model_type
            self.model_cache[model_type] = (
                self.current_model, self.scaler, accuracy, self.current_model_type
            )

            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.status_var.set(f"模型加载完成: {model_name}, 准确率: {accuracy:.4f}")
            self.clear_canvas()
            self.load_performance_data()
        except Exception as e:
            self.status_var.set(f"模型加载失败: {str(e)}")
            self.model_label.config(text="模型加载失败")

    def init_default_model(self):
        """Select and load the first available model."""
        default_type, default_name = self.available_models[0]
        self.model_var.set(default_name)
        self.change_model(default_type)

    def _get_performance_results(self, refresh=False):
        """Return evaluate_all_models() results, computing them only once unless ``refresh``.

        Evaluating retrains every model, so recomputing on each model switch or
        chart request is expensive. The split and the estimators are seeded
        (random_state=42), so the cached numbers are expected to stay valid.
        """
        if refresh or self.performance_results is None:
            self.performance_results = self.model_factory.evaluate_all_models(self.digits)
        return self.performance_results

    def load_performance_data(self, refresh=False):
        """Fill the performance table (best model highlighted).

        Args:
            refresh: re-evaluate all models instead of reusing cached results.
        """
        results = self._get_performance_results(refresh)

        for item in self.performance_tree.get_children():
            self.performance_tree.delete(item)
        for rank, result in enumerate(results):
            tag = "highlight" if rank == 0 else ""
            self.performance_tree.insert(
                "", tk.END, values=(result["模型名称"], result["准确率"]), tags=(tag,)
            )
        self.performance_tree.tag_configure("highlight", background="#e6f7ff")

    # ========================
    # Performance chart
    # ========================
    def show_performance_chart(self):
        """Show a horizontal bar chart of the numeric model accuracies."""
        results = self._get_performance_results()

        # Keep only rows whose accuracy parses as a number ("N/A"/errors are skipped).
        valid_results = []
        for result in results:
            try:
                valid_results.append((result["模型名称"], float(result["准确率"])))
            except ValueError:
                continue

        if not valid_results:
            messagebox.showinfo("提示", "没有可用的性能数据")
            return

        valid_results.sort(key=lambda pair: pair[1], reverse=True)
        models, accuracies = zip(*valid_results)

        plt.figure(figsize=(10, 5))
        bars = plt.barh(models, accuracies, color='#2196F3')
        plt.xlabel('准确率', fontsize=10)
        plt.ylabel('模型', fontsize=10)
        plt.title('模型性能对比', fontsize=12)
        plt.xlim(0, 1.05)

        # Value label just to the right of each bar.
        for bar in bars:
            width = bar.get_width()
            plt.text(
                width + 0.01, bar.get_y() + bar.get_height() / 2,
                f'{width:.4f}', ha='left', va='center', fontsize=8
            )

        plt.tight_layout()
        plt.show()

    # ========================
    # Training-set management
    # ========================
    def save_as_training_sample(self):
        """Preprocess the drawing and ask (modally) for its label before storing it."""
        if not self.has_drawn:
            self.status_var.set("请先绘制数字再保存")
            return

        img_array = self.preprocess_image()
        if img_array is None:
            return

        label_window = tk.Toplevel(self.root)
        label_window.title("输入标签")
        label_window.geometry("300x150")
        label_window.transient(self.root)
        label_window.grab_set()  # modal: block the main window until closed

        tk.Label(label_window, text="请输入数字标签 (0-9):", font=("Arial", 10)).pack(pady=10)

        entry = tk.Entry(label_window, font=("Arial", 12), width=5)
        entry.pack(pady=5)
        entry.focus_set()

        def save_with_label():
            """Validate the typed label (0-9) and append the sample to custom_data."""
            try:
                label = int(entry.get())
                if label < 0 or label > 9:
                    raise ValueError("标签必须是0-9的数字")
                self.custom_data.append((img_array.tolist(), label))
                self.status_var.set(f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)")
                label_window.destroy()
            except ValueError as e:
                self.status_var.set(f"保存错误: {str(e)}")

        entry.bind("<Return>", lambda _event: save_with_label())  # Enter confirms too
        tk.Button(label_window, text="保存", command=save_with_label, width=10).pack(pady=5)

    def save_all_training_data(self):
        """Write custom_data to a CSV file (pixel0..pixel63, label) chosen by the user."""
        if not self.custom_data:
            self.status_var.set("没有训练数据可保存")
            return

        file_path = filedialog.asksaveasfilename(
            defaultextension=".csv",
            filetypes=[("CSV文件", "*.csv")],
            initialdir=self.data_dir,
            initialfile="custom_digits.csv",
            title="保存训练集",
        )
        if not file_path:
            return

        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
                writer.writerows(img_data + [label] for img_data, label in self.custom_data)
            self.status_var.set(
                f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}"
            )
        except Exception as e:
            self.status_var.set(f"保存失败: {str(e)}")

    def load_training_data(self):
        """Replace custom_data with samples read from a CSV file chosen by the user.

        Rows without exactly 65 fields are skipped. If reading fails, the
        previously loaded samples are kept.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("CSV文件", "*.csv")],
            initialdir=self.data_dir,
            title="加载训练集",
        )
        if not file_path:
            return

        try:
            loaded = []
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the header row (tolerates an empty file)
                for row in reader:
                    if len(row) != 65:  # 64 pixels + 1 label
                        continue
                    img_data = [float(pixel) for pixel in row[:64]]
                    label = int(row[64])
                    loaded.append((img_data, label))
        except Exception as e:
            self.status_var.set(f"加载失败: {str(e)}")
            return

        self.custom_data = loaded
        self.status_var.set(f"已加载 {len(self.custom_data)} 个样本")

    # ========================
    # Main loop
    # ========================
    def run(self):
        """Enter the Tk event loop."""
        self.root.mainloop()


# ========================
# Program entry point
# ========================
if __name__ == "__main__":
    digits = load_digits()
    root = tk.Tk()
    app = HandwritingBoard(root, ModelFactory, digits)
    app.run()

# Original page text following the listing, preserved verbatim:
# 请你根据上面的内容生成符合要求的课程设计:
最新发布
06-24
# ======================================================================
# Handwritten digit recognition with several machine-learning models.
#
# The program trains classifiers on scikit-learn's 8x8 ``digits`` dataset
# and recognises digits drawn with the mouse in a Tkinter window.
#
# File layout (top to bottom):
#   1. Imports and matplotlib font setup
#   2. Optional XGBoost / LightGBM detection
#   3. MODEL_METADATA   - registry of every selectable model
#   4. ModelFactory     - create / train / evaluate models
#   5. HandwritingBoard - the GUI: layout, drawing, image preprocessing,
#                         recognition, model switching, charts and
#                         training-sample management
#   6. Program entry point
# ======================================================================


# ======================================================================
# 1. Imports
# ======================================================================
import numpy as np  # numerical arrays
import matplotlib.pyplot as plt  # charts (sample images, accuracy bars)
import pandas as pd  # console tables and LightGBM DataFrame input
import tkinter as tk  # GUI toolkit
from tkinter import ttk, filedialog, messagebox  # themed widgets, dialogs
from PIL import Image, ImageDraw  # off-screen copy of the drawing
import cv2  # image preprocessing (blur, threshold, contours, resize)
import os  # directory creation, file-name helpers
import csv  # reading / writing the custom training set
from sklearn.datasets import load_digits  # the 8x8 digits dataset
from sklearn.model_selection import train_test_split  # train/test split
from sklearn.svm import SVC  # support vector machine
from sklearn.tree import DecisionTreeClassifier  # decision tree
from sklearn.ensemble import RandomForestClassifier  # random forest
from sklearn.neural_network import MLPClassifier  # multi-layer perceptron
from sklearn.neighbors import KNeighborsClassifier  # k-nearest neighbours
from sklearn.naive_bayes import GaussianNB  # Gaussian naive Bayes
from sklearn.metrics import accuracy_score  # accuracy metric
from sklearn.preprocessing import StandardScaler  # feature standardisation

# Use fonts that contain Chinese glyphs, and draw the minus sign with a
# plain hyphen so it does not render as an empty box with those fonts.
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False


# ======================================================================
# 2. Optional libraries (XGBoost / LightGBM)
#    The program still runs without them; their models are simply not
#    offered.
# ======================================================================
XGB_INSTALLED = False
LGB_INSTALLED = False

try:
    import xgboost as xgb
    XGB_INSTALLED = True
except ImportError:
    print("警告: 未安装XGBoost库,无法使用XGBoost模型")

try:
    import lightgbm as lgb
    LGB_INSTALLED = True
except ImportError:
    print("警告: 未安装LightGBM库,无法使用LightGBM模型")


# ======================================================================
# 3. Model registry
#    key -> (display name, estimator class, scaler class or None,
#            keyword arguments for the estimator constructor)
#    Models whose scaler is StandardScaler get standardised inputs.
# ======================================================================
MODEL_METADATA = {
    'svm': ('支持向量机(SVM)', SVC, StandardScaler,
            {'probability': True, 'random_state': 42}),
    'dt': ('决策树(DT)', DecisionTreeClassifier, None,
           {'random_state': 42}),
    'rf': ('随机森林(RF)', RandomForestClassifier, None,
           {'n_estimators': 100, 'random_state': 42}),
    'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler,
            {'hidden_layer_sizes': (100, 50), 'max_iter': 500,
             'random_state': 42}),
    'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler,
            {'n_neighbors': 5, 'weights': 'distance'}),
    'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None, {}),
}

# Register the optional models only when their library imported.
if XGB_INSTALLED:
    # 'multi:softprob' is the objective that makes the booster output a
    # probability per class; the GUI calls predict_proba() to show
    # confidences. Predicted labels are the same as with 'multi:softmax'.
    MODEL_METADATA['xgb'] = ('XGBoost(XGB)', xgb.XGBClassifier, None,
                             {'objective': 'multi:softprob',
                              'random_state': 42})
if LGB_INSTALLED:
    MODEL_METADATA['lgb'] = ('LightGBM(LGB)', lgb.LGBMClassifier, None, {
        'objective': 'multiclass',
        'random_state': 42,
        'num_class': 10,
        'max_depth': 5,
        'min_child_samples': 10,
        'learning_rate': 0.1,
        'force_col_wise': True,
    })


# ======================================================================
# 4. ModelFactory: create, train and evaluate the registered models
# ======================================================================
class ModelFactory:
    """Stateless helpers that build, train and score models from MODEL_METADATA."""

    @staticmethod
    def get_split_data(digits_dataset):
        """Split the dataset into 70% training / 30% test data.

        Args:
            digits_dataset: object with ``data`` (features) and ``target``
                (labels), e.g. the result of ``load_digits()``.

        Returns:
            ``(X_train, X_test, y_train, y_test)``. ``random_state=42``
            makes the split identical on every call.
        """
        X, y = digits_dataset.data, digits_dataset.target
        return train_test_split(X, y, test_size=0.3, random_state=42)

    @classmethod
    def create_model(cls, model_type):
        """Instantiate the estimator (and its scaler, if any) for *model_type*.

        Returns:
            ``(model, scaler)`` where ``scaler`` is ``None`` for models
            registered without one.

        Raises:
            ValueError: *model_type* is not a key of MODEL_METADATA.
            ImportError: the registry entry has no estimator class.
        """
        if model_type not in MODEL_METADATA:
            raise ValueError(f"未知模型类型: {model_type}")

        name, model_cls, scaler_cls, params = MODEL_METADATA[model_type]
        if not model_cls:
            raise ImportError(f"{name}模型依赖库未安装")

        model = model_cls(**params)
        scaler = scaler_cls() if scaler_cls else None
        return model, scaler

    @staticmethod
    def train_model(model, X_train, y_train, scaler=None, model_type=None):
        """Fit *scaler* (if given) and then *model* on the training data.

        For LightGBM a numpy array is wrapped in a DataFrame so the fitted
        model records feature names (``feature_name_``) that
        ``evaluate_model`` and ``HandwritingBoard.recognize`` reuse.

        Returns:
            The fitted *model*.
        """
        if scaler is not None:
            X_train = scaler.fit_transform(X_train)
        if model_type == 'lgb' and isinstance(X_train, np.ndarray):
            X_train = pd.DataFrame(X_train)
        model.fit(X_train, y_train)
        return model

    @staticmethod
    def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
        """Return the accuracy of a fitted *model* on the test data.

        The already-fitted *scaler* is applied with ``transform`` (not
        ``fit_transform``) so test data is scaled like the training data.
        """
        if scaler is not None:
            X_test = scaler.transform(X_test)
        if (model_type == 'lgb' and isinstance(X_test, np.ndarray)
                and hasattr(model, 'feature_name_')):
            X_test = pd.DataFrame(X_test, columns=model.feature_name_)
        y_pred = model.predict(X_test)
        return accuracy_score(y_test, y_pred)

    @classmethod
    def train_and_evaluate(cls, model_type, X_train, y_train, X_test, y_test):
        """Create, train and score one model.

        Returns:
            ``(model, scaler, accuracy)``.

        Raises:
            Whatever creation/training/evaluation raised; the error is also
            printed to the console first.
        """
        try:
            model, scaler = cls.create_model(model_type)
            model = cls.train_model(model, X_train, y_train, scaler, model_type)
            accuracy = cls.evaluate_model(model, X_test, y_test, scaler, model_type)
            return model, scaler, accuracy
        except Exception as e:
            print(f"模型 {model_type} 训练/评估错误: {str(e)}")
            raise

    @staticmethod
    def _accuracy_sort_key(result):
        """Sort key: the numeric accuracy, or -1 for "N/A" / error entries."""
        value = result["准确率"]
        if isinstance(value, str) and value.replace('.', '', 1).isdigit():
            return float(value)
        return -1

    @classmethod
    def evaluate_all_models(cls, digits_dataset):
        """Train and score every registered model on the same split.

        Returns:
            A list of ``{"模型名称": name, "准确率": text}`` dicts sorted
            best-first. ``text`` is the accuracy formatted with 4 decimals,
            "N/A" when the estimator class is missing, or "错误: ..." when
            training failed. The table is also printed to the console.
        """
        print("\n=== 模型评估 ===")
        X_train, X_test, y_train, y_test = cls.get_split_data(digits_dataset)

        results = []
        for model_type, (name, model_cls, _, _) in MODEL_METADATA.items():
            print(f"评估模型: {name} ({model_type})")
            if not model_cls:
                results.append({"模型名称": name, "准确率": "N/A"})
                continue
            try:
                _, _, accuracy = cls.train_and_evaluate(
                    model_type, X_train, y_train, X_test, y_test
                )
                results.append({"模型名称": name, "准确率": f"{accuracy:.4f}"})
            except Exception as e:
                results.append({"模型名称": name, "准确率": f"错误: {str(e)}"})

        # Best accuracy first; "N/A" and error rows sink to the bottom.
        results.sort(key=cls._accuracy_sort_key, reverse=True)
        print(pd.DataFrame(results))
        return results


# ======================================================================
# 5. HandwritingBoard: the graphical user interface
# ======================================================================
class HandwritingBoard:
    """Tkinter application: draw a digit, recognise it, compare models.

    Constructor args:
        root: the ``tk.Tk`` main window.
        model_factory: object with the ModelFactory interface.
        digits: the dataset returned by ``load_digits()``.
    """

    CANVAS_SIZE = 300  # the drawing area is CANVAS_SIZE x CANVAS_SIZE px
    BRUSH_SIZE = 12  # stroke width in pixels

    # ------------------------------------------------------------------
    # 5.1 Initialisation
    # ------------------------------------------------------------------
    def __init__(self, root, model_factory, digits):
        self.root = root
        self.root.title("手写数字识别系统")
        self.root.geometry("1000x700")  # large enough for all panels

        self.model_factory = model_factory
        self.digits = digits

        # Trained models: model_type -> (model, scaler, accuracy, model_type)
        self.model_cache = {}
        self.current_model = None
        self.scaler = None
        self.current_model_type = None
        # Output of evaluate_all_models(); computed on first use and reused
        # because the split and random seeds are fixed.
        self.performance_results = None

        # Drawing state
        self.has_drawn = False
        self.drawing = False
        self.last_x = self.last_y = 0

        # Samples saved by the user: list of (64 pixel values, label)
        self.custom_data = []

        # Directory for custom data (created if missing)
        self.data_dir = "custom_digits_data"
        os.makedirs(self.data_dir, exist_ok=True)

        # Grayscale off-screen image that mirrors the strokes drawn on the
        # Tk canvas (255 = white paper, 0 = black ink). This is what gets
        # preprocessed and classified.
        self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw_obj = ImageDraw.Draw(self.image)

        self.create_widgets()
        self.init_default_model()

    # ------------------------------------------------------------------
    # 5.2 Window layout
    # ------------------------------------------------------------------
    def create_widgets(self):
        """Build the whole window using the grid geometry manager.

        Layout of ``main_frame``:
            row 0: model selector (spans both columns)
            row 1: drawing area (col 0)    | recognition results (col 1)
            row 2: performance table (col 0) | training-set buttons (col 1)
        A status bar is packed along the bottom of the root window.
        """
        main_frame = tk.Frame(self.root)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self._build_model_selector(main_frame)
        self._build_drawing_area(main_frame)
        self._build_result_panel(main_frame)
        self._build_performance_panel(main_frame)
        self._build_training_panel(main_frame)
        self._build_status_bar()

        # Both columns and the two lower rows share extra space equally.
        main_frame.grid_columnconfigure(0, weight=1)
        main_frame.grid_columnconfigure(1, weight=1)
        main_frame.grid_rowconfigure(1, weight=1)
        main_frame.grid_rowconfigure(2, weight=1)

    def _build_model_selector(self, parent):
        """Row 0: a combobox listing the registered models plus an info label."""
        model_frame = tk.LabelFrame(parent, text="模型选择", font=("Arial", 10, "bold"))
        model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", padx=5, pady=5)
        model_frame.grid_columnconfigure(1, weight=1)  # let the combobox stretch

        tk.Label(model_frame, text="选择模型:", font=("Arial", 10)).grid(
            row=0, column=0, padx=5, pady=5, sticky="w"
        )

        # (model_type, display name) pairs in registry order
        self.available_models = [
            (model_type, name)
            for model_type, (name, model_cls, _, _) in MODEL_METADATA.items()
            if model_cls
        ]

        self.model_var = tk.StringVar()
        self.model_combobox = ttk.Combobox(
            model_frame,
            textvariable=self.model_var,
            values=[name for _, name in self.available_models],
            state="readonly",
            width=25,
            font=("Arial", 10),
        )
        self.model_combobox.current(0)
        self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select)
        self.model_combobox.grid(row=0, column=1, padx=5, pady=5, sticky="ew")

        # Shows "<model name> (准确率:...)" once a model is loaded
        self.model_label = tk.Label(
            model_frame, text="", font=("Arial", 10), relief=tk.SUNKEN, padx=5, pady=2
        )
        self.model_label.grid(row=0, column=2, padx=5, pady=5, sticky="ew")

    def _build_drawing_area(self, parent):
        """Row 1, left: the drawing canvas and its three action buttons."""
        left_frame = tk.LabelFrame(parent, text="绘制区域", font=("Arial", 10, "bold"))
        left_frame.grid(row=1, column=0, padx=5, pady=5, sticky="nsew")

        self.canvas = tk.Canvas(
            left_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE
        )
        self.canvas.pack(padx=10, pady=10)
        # Mouse events: press -> start, drag -> draw, release -> stop
        self.canvas.bind("<Button-1>", self.start_draw)
        self.canvas.bind("<B1-Motion>", self.draw)
        self.canvas.bind("<ButtonRelease-1>", self.stop_draw)
        self._draw_canvas_hint()

        btn_frame = tk.Frame(left_frame)
        btn_frame.pack(fill=tk.X, pady=(0, 10))
        for text, command in (
            ("识别", self.recognize),
            ("清除", self.clear_canvas),
            ("样本", self.show_samples),
        ):
            tk.Button(btn_frame, text=text, command=command, width=8).pack(
                side=tk.LEFT, padx=5
            )

    def _draw_canvas_hint(self):
        """Draw the gray placeholder text in the middle of the canvas.

        It is tagged "hint" so start_draw() can remove it. It exists only on
        the Tk canvas, never in ``self.image``, so it is never classified.
        """
        self.canvas.create_text(
            self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
            text="绘制数字", fill="gray", font=("Arial", 16), tags="hint",
        )

    def _build_result_panel(self, parent):
        """Row 1, right: prediction text, confidence bar and top-3 table."""
        right_frame = tk.Frame(parent)
        right_frame.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")

        # Predicted digit and its confidence as text
        result_frame = tk.LabelFrame(right_frame, text="识别结果", font=("Arial", 10, "bold"))
        result_frame.pack(fill=tk.X, padx=5, pady=5)
        self.result_label = tk.Label(result_frame, text="请绘制数字", font=("Arial", 24), pady=10)
        self.result_label.pack()
        self.prob_label = tk.Label(result_frame, text="", font=("Arial", 12))
        self.prob_label.pack()

        # Confidence bar. 65 px tall so the percentage labels that
        # update_confidence_display() draws at y=55 fit inside the canvas.
        confidence_frame = tk.LabelFrame(right_frame, text="识别置信度", font=("Arial", 10, "bold"))
        confidence_frame.pack(fill=tk.X, padx=5, pady=5)
        self.confidence_canvas = tk.Canvas(confidence_frame, bg="white", height=65)
        self.confidence_canvas.pack(fill=tk.X, padx=10, pady=10)
        self.confidence_canvas.create_text(
            150, 25, text="识别后显示置信度", fill="gray", font=("Arial", 10)
        )

        # Table of the most likely digits
        candidates_frame = tk.LabelFrame(right_frame, text="可能的数字", font=("Arial", 10, "bold"))
        candidates_frame.pack(fill=tk.X, padx=5, pady=5)
        self.candidates_tree = self._make_table(
            candidates_frame, ("数字", "概率"), height=4, col_width=80
        )

    def _build_performance_panel(self, parent):
        """Row 2, left: table comparing the accuracy of every model."""
        performance_frame = tk.LabelFrame(parent, text="模型性能对比", font=("Arial", 10, "bold"))
        performance_frame.grid(row=2, column=0, padx=5, pady=5, sticky="nsew")
        self.performance_tree = self._make_table(
            performance_frame, ("模型名称", "准确率"), height=8, col_width=120
        )

    def _build_training_panel(self, parent):
        """Row 2, right: 2x2 grid of training-set / chart buttons."""
        train_frame = tk.LabelFrame(parent, text="训练集管理", font=("Arial", 10, "bold"))
        train_frame.grid(row=2, column=1, padx=5, pady=5, sticky="nsew")

        # (text, handler, grid row, grid column)
        buttons = (
            ("保存为训练样本", self.save_as_training_sample, 0, 0),
            ("保存全部训练集", self.save_all_training_data, 0, 1),
            ("加载训练集", self.load_training_data, 1, 0),
            ("性能图表", self.show_performance_chart, 1, 1),
        )
        for text, command, row, column in buttons:
            tk.Button(train_frame, text=text, command=command, width=18, height=2).grid(
                row=row, column=column, padx=5, pady=5, sticky="ew"
            )

    def _build_status_bar(self):
        """Status line at the bottom of the window, driven by ``status_var``."""
        self.status_var = tk.StringVar(value="就绪")
        status_bar = tk.Label(
            self.root,
            textvariable=self.status_var,
            bd=1,
            relief=tk.SUNKEN,
            anchor=tk.W,
            font=("Arial", 10),
        )
        status_bar.pack(side=tk.BOTTOM, fill=tk.X)

    @staticmethod
    def _make_table(parent, columns, height, col_width):
        """Create a headings-only Treeview with a vertical scrollbar in *parent*.

        Returns:
            The new ``ttk.Treeview``.
        """
        tree = ttk.Treeview(parent, columns=columns, show="headings", height=height)
        for col in columns:
            tree.heading(col, text=col)
            tree.column(col, width=col_width, anchor=tk.CENTER)
        scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=tree.yview)
        tree.configure(yscrollcommand=scrollbar.set)
        tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
        return tree

    # ------------------------------------------------------------------
    # 5.3 Drawing with the mouse
    # ------------------------------------------------------------------
    def start_draw(self, event):
        """Mouse pressed: remember the starting point."""
        self.drawing = True
        self.last_x, self.last_y = event.x, event.y
        # Remove the placeholder text so it does not overlap the strokes.
        self.canvas.delete("hint")

    def draw(self, event):
        """Mouse dragged: draw a segment on the canvas AND on ``self.image``."""
        if not self.drawing:
            return

        x, y = event.x, event.y

        # Visible stroke on the Tk canvas
        self.canvas.create_line(
            self.last_x, self.last_y, x, y,
            fill="black", width=self.BRUSH_SIZE, capstyle=tk.ROUND, smooth=True,
        )

        # Same stroke on the off-screen image that will be classified.
        # PIL line segments have flat ends; stamping a disc at each point
        # mimics the canvas's round caps so segments join without gaps.
        self.draw_obj.line([self.last_x, self.last_y, x, y], fill=0, width=self.BRUSH_SIZE)
        radius = self.BRUSH_SIZE / 2
        self.draw_obj.ellipse([x - radius, y - radius, x + radius, y + radius], fill=0)

        self.last_x, self.last_y = x, y

    def stop_draw(self, event):
        """Mouse released: mark that something has been drawn."""
        self.drawing = False
        self.has_drawn = True
        self.status_var.set("已绘制数字,点击'识别'进行识别")

    def clear_canvas(self):
        """Erase the drawing (canvas and image) and reset the result widgets."""
        self.canvas.delete("all")
        self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw_obj = ImageDraw.Draw(self.image)
        self._draw_canvas_hint()

        self.result_label.config(text="请绘制数字")
        self.prob_label.config(text="")
        self.clear_confidence_display()
        self.has_drawn = False
        self.status_var.set("画布已清除")

    def clear_confidence_display(self):
        """Reset the confidence bar to its placeholder and empty the top-3 table."""
        self.confidence_canvas.delete("all")
        self.confidence_canvas.create_text(
            150, 25, text="识别后显示置信度", fill="gray", font=("Arial", 10)
        )
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)

    # ------------------------------------------------------------------
    # 5.4 Image preprocessing and recognition
    # ------------------------------------------------------------------
    def preprocess_image(self):
        """Turn the drawing into a 64-value feature vector like ``digits.data``.

        Steps: blur -> binarise (inverted, so ink = 255 and paper = 0) ->
        crop to the bounding box of all strokes -> centre in a square ->
        shrink to 8x8 -> rescale 0..255 to 0..16.

        The result uses the training data's convention: background pixels
        are 0 and ink pixels approach 16.

        Returns:
            A flat numpy array of 64 floats, or ``None`` (after setting a
            status message) when no stroke is found.
        """
        gray = np.array(self.image)

        # Smooth jagged stroke edges.
        gray = cv2.GaussianBlur(gray, (5, 5), 0)

        # Inverted binary image: dark ink (<= 127) -> 255, white paper -> 0.
        _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)

        # [-2] selects the contour list for both OpenCV 3 (3 return values)
        # and OpenCV 4 (2 return values).
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            self.status_var.set("未检测到有效数字,请重新绘制")
            return None

        # Bounding box around ALL strokes, so digits drawn as several
        # separate strokes (e.g. "4" or "5") are not cut down to one piece.
        x, y, w, h = cv2.boundingRect(np.vstack(contours))
        digit = binary[y:y + h, x:x + w]

        # Centre the crop in a square filled with background (0 in the
        # inverted image) so resizing keeps the digit's aspect ratio.
        size = max(w, h)
        padded = np.zeros((size, size), dtype=np.uint8)
        offset_x = (size - w) // 2
        offset_y = (size - h) // 2
        padded[offset_y:offset_y + h, offset_x:offset_x + w] = digit

        # Shrink to the dataset's 8x8 resolution; INTER_AREA averages pixels.
        resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA)

        # Map 0..255 to 0..16: background -> 0, full ink -> 16.
        normalized = np.rint(resized.astype(np.float64) / 255.0 * 16.0)
        return normalized.flatten()

    def recognize(self):
        """Classify the current drawing with the selected model and show the result."""
        if not self.has_drawn:
            self.status_var.set("请先绘制数字再识别")
            return
        if self.current_model is None:
            self.status_var.set("模型未加载,请选择模型")
            return

        img_array = self.preprocess_image()
        if img_array is None:
            return

        img_input = img_array.reshape(1, -1)  # one sample, 64 features

        try:
            # Scale exactly like the training data was scaled.
            if self.scaler:
                img_input = self.scaler.transform(img_input)

            # LightGBM was trained on a DataFrame; reuse its column names.
            if self.current_model_type == 'lgb' and hasattr(self.current_model, 'feature_name_'):
                img_input = pd.DataFrame(img_input, columns=self.current_model.feature_name_)

            pred = self.current_model.predict(img_input)[0]
            self.result_label.config(text=f"识别结果: {pred}")

            if hasattr(self.current_model, 'predict_proba'):
                probs = self.current_model.predict_proba(img_input)[0]
                # Column i of predict_proba belongs to classes_[i]; look the
                # prediction up there instead of assuming label == column.
                classes = list(getattr(self.current_model, 'classes_', range(len(probs))))
                confidence = probs[classes.index(pred)]

                self.prob_label.config(text=f"置信度: {confidence:.2%}")
                self.update_confidence_display(confidence)

                # Three most probable digits, highest first
                top3 = sorted(zip(classes, probs), key=lambda item: item[1], reverse=True)[:3]
                self.update_candidates_display(top3)
            else:
                self.prob_label.config(text="该模型不支持概率输出")
                self.clear_confidence_display()

            self.status_var.set(f"识别完成: 数字 {pred}")

        except Exception as e:
            self.status_var.set(f"识别错误: {str(e)}")
            self.clear_confidence_display()

    def update_confidence_display(self, confidence):
        """Draw a horizontal bar whose length and colour reflect *confidence* (0..1)."""
        self.confidence_canvas.delete("all")

        # winfo_width() reports 1 before the widget has been laid out.
        canvas_width = self.confidence_canvas.winfo_width()
        if canvas_width <= 1:
            canvas_width = 300

        # Gray background track
        self.confidence_canvas.create_rectangle(
            10, 10, canvas_width - 10, 40, fill="#f0f0f0", outline="#cccccc"
        )

        # Coloured bar proportional to the confidence
        bar_width = int((canvas_width - 20) * confidence)
        self.confidence_canvas.create_rectangle(
            10, 10, 10 + bar_width, 40,
            fill=self.get_confidence_color(confidence), outline="",
        )

        # Percentage text in the middle
        self.confidence_canvas.create_text(
            canvas_width / 2, 25, text=f"{confidence:.1%}", font=("Arial", 10, "bold")
        )

        # Tick marks every 10%, labels every 20%
        for i in range(0, 11):
            x_pos = 10 + i * (canvas_width - 20) / 10
            self.confidence_canvas.create_line(x_pos, 40, x_pos, 45, width=1)
            if i % 2 == 0:
                self.confidence_canvas.create_text(
                    x_pos, 55, text=f"{i*10}%", font=("Arial", 8)
                )

    def get_confidence_color(self, confidence):
        """Green for >= 0.9, amber for >= 0.7, red otherwise."""
        if confidence >= 0.9:
            return "#4CAF50"  # green
        if confidence >= 0.7:
            return "#FFC107"  # amber
        return "#F44336"  # red

    def update_candidates_display(self, candidates):
        """Replace the top-3 table with *candidates*: iterable of (digit, probability)."""
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)
        for digit, prob in candidates:
            self.candidates_tree.insert("", tk.END, values=(digit, f"{prob:.2%}"))

    def show_samples(self):
        """Show one example image of each digit 0-9 from the dataset."""
        plt.figure(figsize=(10, 4))
        for i in range(10):
            plt.subplot(2, 5, i + 1)
            sample_idx = np.where(self.digits.target == i)[0][0]  # first sample of digit i
            plt.imshow(self.digits.images[sample_idx], cmap="gray")
            plt.title(f"数字 {i}", fontsize=9)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # 5.5 Model selection
    # ------------------------------------------------------------------
    def on_model_select(self, event):
        """Combobox handler: translate the chosen display name into a model key."""
        selected_name = self.model_var.get()
        model_type = next(
            (key for key, name in self.available_models if name == selected_name),
            None,
        )
        if model_type:
            self.change_model(model_type)

    def change_model(self, model_type):
        """Make *model_type* the active model, training it on first use."""
        model_name = MODEL_METADATA[model_type][0]

        # Already trained: reuse it.
        if model_type in self.model_cache:
            (self.current_model, self.scaler, accuracy,
             self.current_model_type) = self.model_cache[model_type]
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.status_var.set(f"已加载模型: {model_name}")
            return

        self.status_var.set(f"正在加载模型: {model_name}...")
        self.root.update()  # repaint now so the message shows during training

        try:
            X_train, X_test, y_train, y_test = self.model_factory.get_split_data(self.digits)
            self.current_model, self.scaler, accuracy = self.model_factory.train_and_evaluate(
                model_type, X_train, y_train, X_test, y_test
            )
            self.current_model_type = model_type
            self.model_cache[model_type] = (
                self.current_model, self.scaler, accuracy, self.current_model_type
            )
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")

            self.clear_canvas()
            self.load_performance_data()
            # Set last: clear_canvas() writes its own status message, which
            # would otherwise hide this one.
            self.status_var.set(f"模型加载完成: {model_name}, 准确率: {accuracy:.4f}")
        except Exception as e:
            self.status_var.set(f"模型加载失败: {str(e)}")
            self.model_label.config(text="模型加载失败")

    def init_default_model(self):
        """Select and load the first available model at start-up."""
        self.model_var.set(self.available_models[0][1])
        self.change_model(self.available_models[0][0])

    # ------------------------------------------------------------------
    # 5.6 Model comparison (table and chart)
    # ------------------------------------------------------------------
    def _get_performance_results(self):
        """Return evaluate_all_models() results, computing them only once.

        The split and every model's random seed are fixed, so re-running
        the (slow) evaluation would only repeat the same work.
        """
        if self.performance_results is None:
            self.performance_results = self.model_factory.evaluate_all_models(self.digits)
        return self.performance_results

    def load_performance_data(self):
        """Fill the comparison table; the best model (first row) is highlighted."""
        results = self._get_performance_results()

        for item in self.performance_tree.get_children():
            self.performance_tree.delete(item)

        for index, result in enumerate(results):
            tag = "highlight" if index == 0 else ""
            self.performance_tree.insert(
                "", tk.END,
                values=(result["模型名称"], result["准确率"]),
                tags=(tag,),
            )
        self.performance_tree.tag_configure("highlight", background="#e6f7ff")

    def show_performance_chart(self):
        """Show a horizontal bar chart of every model's test accuracy."""
        results = self._get_performance_results()

        # Keep only rows whose accuracy is a number ("N/A"/errors are skipped).
        valid_results = []
        for result in results:
            try:
                valid_results.append((result["模型名称"], float(result["准确率"])))
            except ValueError:
                continue

        if not valid_results:
            messagebox.showinfo("提示", "没有可用的性能数据")
            return

        valid_results.sort(key=lambda pair: pair[1], reverse=True)
        models, accuracies = zip(*valid_results)

        plt.figure(figsize=(10, 5))
        bars = plt.barh(models, accuracies, color='#2196F3')
        plt.xlabel('准确率', fontsize=10)
        plt.ylabel('模型', fontsize=10)
        plt.title('模型性能对比', fontsize=12)
        plt.xlim(0, 1.05)  # leave room for the value labels

        # Write each accuracy just right of its bar
        for bar in bars:
            width = bar.get_width()
            plt.text(
                width + 0.01, bar.get_y() + bar.get_height() / 2,
                f'{width:.4f}', ha='left', va='center', fontsize=8,
            )

        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # 5.7 Custom training-set management
    # ------------------------------------------------------------------
    def save_as_training_sample(self):
        """Ask for a label (0-9) and store the current drawing in ``custom_data``."""
        if not self.has_drawn:
            self.status_var.set("请先绘制数字再保存")
            return

        img_array = self.preprocess_image()
        if img_array is None:
            return

        # Small modal dialog asking for the label
        label_window = tk.Toplevel(self.root)
        label_window.title("输入标签")
        label_window.geometry("300x150")
        label_window.transient(self.root)
        label_window.grab_set()  # block the main window until closed

        tk.Label(label_window, text="请输入数字标签 (0-9):", font=("Arial", 10)).pack(pady=10)

        entry = tk.Entry(label_window, font=("Arial", 12), width=5)
        entry.pack(pady=5)
        entry.focus_set()

        def save_with_label():
            """Validate the entered label, store the sample and close the dialog."""
            try:
                label = int(entry.get())
                if label < 0 or label > 9:
                    raise ValueError("标签必须是0-9的数字")
                self.custom_data.append((img_array.tolist(), label))
                self.status_var.set(f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)")
                label_window.destroy()
            except ValueError as e:
                self.status_var.set(f"保存错误: {str(e)}")

        tk.Button(label_window, text="保存", command=save_with_label, width=10).pack(pady=5)

    def save_all_training_data(self):
        """Write every stored sample to a CSV file (64 pixel columns + label)."""
        if not self.custom_data:
            self.status_var.set("没有训练数据可保存")
            return

        file_path = filedialog.asksaveasfilename(
            defaultextension=".csv",
            filetypes=[("CSV文件", "*.csv")],
            initialfile="custom_digits.csv",
            title="保存训练集",
        )
        if not file_path:  # dialog cancelled
            return

        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
                for img_data, label in self.custom_data:
                    writer.writerow(img_data + [label])
            self.status_var.set(
                f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}"
            )
        except Exception as e:
            self.status_var.set(f"保存失败: {str(e)}")

    def load_training_data(self):
        """Replace ``custom_data`` with the samples from a CSV file.

        Rows without exactly 65 fields are skipped. The file is parsed into
        a temporary list first, so a read/parse error leaves the samples
        already in memory untouched.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("CSV文件", "*.csv")],
            title="加载训练集",
        )
        if not file_path:  # dialog cancelled
            return

        loaded = []
        try:
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the header row (tolerates an empty file)
                for row in reader:
                    if len(row) != 65:  # 64 pixels + 1 label
                        continue
                    img_data = [float(pixel) for pixel in row[:64]]
                    label = int(row[64])
                    loaded.append((img_data, label))
        except Exception as e:
            self.status_var.set(f"加载失败: {str(e)}")
            return

        self.custom_data = loaded
        self.status_var.set(f"已加载 {len(self.custom_data)} 个样本")

    # ------------------------------------------------------------------
    # 5.8 Main loop
    # ------------------------------------------------------------------
    def run(self):
        """Start the Tkinter event loop (blocks until the window closes)."""
        self.root.mainloop()


# ======================================================================
# 6. Program entry point
# ======================================================================
if __name__ == "__main__":
    digits = load_digits()  # 8x8 images of digits 0-9
    root = tk.Tk()  # main window
    app = HandwritingBoard(root, ModelFactory, digits)
    app.run()

# Request text that followed the code on the original page (kept verbatim):
# 基于此代码,在其中做好大量注释,同时要明确代码的分区功能,要显示明白,让刚学python的同学要能看懂。
06-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

【拾光静好 微微一笑】

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值