《Learning OpenCV3》——第六章 绘图和注释

本文介绍OpenCV3中的绘图功能,包括线、多边形等基本图形的绘制方法及其参数设置,并探讨了如何利用OpenCV3在图像上绘制文本。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

第六章 绘图和注释

OpenCV3提供在图像上绘图的功能。通常情况下,绘图操作涉及单通道(灰度图像)和三通道(彩色图像)操作,Alpha通道的绘制暂时不支持。此外,OpenCV3组织彩色图像的方式为BGR排列,而不是常见的RGB排列顺序。

一:线和填充多边形

线绘制的时候往往会涉及两个参数: thickness和lineType 。一般lineType取值有三种:4,8或者cv::LINE_AA;thickness为线的宽度,对于圆形、矩形或者一些其他的封闭多边形,thickness可以设置为cv::FILLED,此时将对多边形使用边界颜色进行填充。
这里写图片描述
一般来说,绘图的起始点、结束点、角点等参数都为整数类型,但OpenCV3绘制函数支持的 shitf 参数可以进行亚像素点即非整数点的绘制。一般在绘制函数中传入shift值作为小数位数使用。
绘图功能:

函数描述
cv::circle()Draw a simple circle
cv::clipLine()Determine if a line is inside a given box
cv::ellipse()Draw an ellipse, which may be tilted or an elliptical arc
cv::ellipse2Poly()Compute a polygon approximation to an elliptical arc
cv::fillConvexPoly()Draw filled versions of simple polygons
cv::fillPoly()Draw filled versions of arbitrary polygons
cv::line()Draw a simple line

二:字体和文本

文本绘制功能:

函数描述
cv::putText()Draw the specified text in an image
cv::getTextSize()Determine the width and height of a text string
课程设计与要求: 实验18 手写数字识别程序设计与实现 实验类型:设计性实验 实验学时:8 涉及的知识点:SVM、决策树、随机森林、XGBoostLightGBM机器学习算法的综合应用 一、 实验目的 1、 了解机器学习算法应用项目设计流程与基本方法。 2、 掌握SVM应用设计与K折交叉验证法获得测试数据。 3、 熟悉两种以上不同类型机器学习算法及应用。 4、 掌握各类机器学习算法的区别、优缺点;会应用网格搜索选择最优超参数。 5、 掌握分类任务的性能指标评价方法。 二、 实验要求 1、 使用anaconda集成开发环境完成课程设计,代码的可维护性好,有必要的注释相应的文档。 2、 能够识别符合分辨率要求的手写数字。 3、 构建不同模型实现手写数字分类识别,至少要对比两种方法,如决策树、支持向量机、随机森林、XGBoostLightGBM等。对比不同模型的分类性能报告,评价模型好坏。 4、 数据集采用sklearn.datasets中的digits,测试集数据可以用自己手写产生或者从digits中拆分。 三、 设计指标 1、 完整的设计文档 1) 系统的需求分析 2) 系统的概要设计 3) 详细设计与实现 4) 系统测试方法 2、 运行画面截图 3、 每一部分附上关键性代码 4、 项目总结 四、 预习与参考 1、 教材有关决策树、SVM(支持向量机)、随机森林、XGBoostLightGBM有关章节。 2、 课程PPT有关内容。 3、 中国知网有关手写数字识别的文献资料。 五、 考核形式 根据提交的设计文档完成程度以及程序功能的实现情况(要求演示)进行考核:  无任何文档,无程序,得 0 分;  文档描述不清楚,思路混乱,程序不能运行,2分;  文档描述清晰,程序实现了基本功能,3.5分;  文档描述清晰准确,思路清晰,程序实现了要求的所有功能,4. 5分;  文档完备,设计合理有创新,报告清晰明确,深入分析了自己进行实验的体会感想,程序实现了全部功能,功能完善,并有其它的创新实现,5分。 六、 实验报告要求 1、 实验目的结合自己个人的实际情况书写,不要雷同。 2、 项目概要设计说明书(描述软件系统架构、逻辑架构、物理架构、部署结构、功能架构及关键技术,关键业务模块需通过UML图进行详细描述)、需求规格说明书(包括功能设计、非功能性设计、系统用例)。 3、 项目设计运行截图。 代码程序: # ======================== # 导入必要的库 # ======================== import numpy as np # 数值计算库 import matplotlib.pyplot as plt # 绘图库 import pandas as pd # 数据处理库 import tkinter as tk # GUI库 from tkinter import ttk, filedialog, messagebox # GUI组件 from PIL import Image, ImageDraw # 图像处理库 import cv2 # 计算机视觉库 import os # 操作系统接口 import csv # CSV文件处理 from sklearn.datasets import load_digits # 加载数字数据集 from sklearn.model_selection import train_test_split # 数据集划分 from sklearn.svm import SVC # 支持向量机模型 from sklearn.tree import DecisionTreeClassifier # 决策树模型 from sklearn.ensemble import RandomForestClassifier # 随机森林模型 from sklearn.neural_network import MLPClassifier # 多层感知机模型 from sklearn.neighbors import KNeighborsClassifier # K近邻模型 from sklearn.naive_bayes import GaussianNB # 朴素贝叶斯模型 from sklearn.metrics import accuracy_score # 准确率评估 from sklearn.preprocessing import StandardScaler # 数据标准化 # 设置中文字体负号显示(解决中文乱码问题) plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"] plt.rcParams["axes.unicode_minus"] = False # ======================== # 尝试导入可选模型库 # ======================== XGB_INSTALLED = False # 标记XGBoost是否安装 LGB_INSTALLED = False # 标记LightGBM是否安装 try: import xgboost as xgb # XGBoost模型 XGB_INSTALLED = True except ImportError: print("警告: 未安装XGBoost库,无法使用XGBoost模型") try: import lightgbm as lgb # LightGBM模型 LGB_INSTALLED = True except ImportError: print("警告: 未安装LightGBM库,无法使用LightGBM模型") # ======================== # 模型配置 # ======================== # 定义模型元数据(包含模型名称、类、标准化器参数) MODEL_METADATA = { &#39;svm&#39;: (&#39;支持向量机(SVM)&#39;, SVC, StandardScaler, {&#39;probability&#39;: True, &#39;random_state&#39;: 42}), &#39;dt&#39;: (&#39;决策树(DT)&#39;, DecisionTreeClassifier, None, {&#39;random_state&#39;: 42}), &#39;rf&#39;: (&#39;随机森林(RF)&#39;, RandomForestClassifier, None, {&#39;n_estimators&#39;: 100, &#39;random_state&#39;: 42}), &#39;mlp&#39;: (&#39;多层感知机(MLP)&#39;, MLPClassifier, StandardScaler, {&#39;hidden_layer_sizes&#39;: (100, 50), &#39;max_iter&#39;: 500, &#39;random_state&#39;: 42}), &#39;knn&#39;: (&#39;K最近邻(KNN)&#39;, KNeighborsClassifier, StandardScaler, {&#39;n_neighbors&#39;: 5, &#39;weights&#39;: &#39;distance&#39;}), &#39;nb&#39;: (&#39;高斯朴素贝叶斯(NB)&#39;, GaussianNB, None, {}), } # 添加可选模型(如果已安装) if XGB_INSTALLED: MODEL_METADATA[&#39;xgb&#39;] = (&#39;XGBoost(XGB)&#39;, xgb.XGBClassifier, None, {&#39;objective&#39;: &#39;multi:softmax&#39;, &#39;random_state&#39;: 42}) if LGB_INSTALLED: MODEL_METADATA[&#39;lgb&#39;] = (&#39;LightGBM(LGB)&#39;, lgb.LGBMClassifier, None, { &#39;objective&#39;: &#39;multiclass&#39;, &#39;random_state&#39;: 42, &#39;num_class&#39;: 10, &#39;max_depth&#39;: 5, &#39;min_child_samples&#39;: 10, &#39;learning_rate&#39;: 0.1, &#39;force_col_wise&#39;: True }) # ======================== # 模型工厂类 - 负责模型创建、训练评估 # ======================== class ModelFactory: @staticmethod def get_split_data(digits_dataset): """数据集划分""" X, y = digits_dataset.data, digits_dataset.target # 获取特征标签 # 划分训练集测试集(70%训练,30%测试) return train_test_split(X, y, test_size=0.3, random_state=42) @classmethod def create_model(cls, model_type): """创建模型数据标准化器""" # 检查模型类型是否有效 if model_type not in MODEL_METADATA: raise ValueError(f"未知模型类型: {model_type}") # 从配置中获取模型信息 name, model_cls, scaler_cls, params = MODEL_METADATA[model_type] # 创建模型实例标准化器 model = model_cls(**params) scaler = scaler_cls() if scaler_cls else None return model, scaler @staticmethod def train_model(model, X_train, y_train, scaler=None, model_type=None): """训练模型""" # 数据标准化处理 if scaler: X_train = scaler.fit_transform(X_train) # LightGBM特殊处理(需要DataFrame格式) if model_type == &#39;lgb&#39; and isinstance(X_train, np.ndarray): X_train = pd.DataFrame(X_train) # 训练模型 model.fit(X_train, y_train) return model @staticmethod def evaluate_model(model, X_test, y_test, scaler=None, model_type=None): """评估模型""" # 数据标准化处理 if scaler: X_test = scaler.transform(X_test) # LightGBM特殊处理 if model_type == &#39;lgb&#39; and isinstance(X_test, np.ndarray) and hasattr(model, &#39;feature_name_&#39;): X_test = pd.DataFrame(X_test, columns=model.feature_name_) # 预测并计算准确率 y_pred = model.predict(X_test) return accuracy_score(y_test, y_pred) @classmethod def train_and_evaluate(cls, model_type, X_train, y_train, X_test, y_test): """训练并评估模型""" try: # 创建模型 model, scaler = cls.create_model(model_type) # 训练模型 model = cls.train_model(model, X_train, y_train, scaler, model_type) # 评估模型 accuracy = cls.evaluate_model(model, X_test, y_test, scaler, model_type) return model, scaler, accuracy except Exception as e: print(f"模型 {model_type} 训练/评估错误: {str(e)}") raise @classmethod def evaluate_all_models(cls, digits_dataset): """评估所有可用模型""" print("\n=== 模型评估 ===") # 划分数据集 X_train, X_test, y_train, y_test = cls.get_split_data(digits_dataset) results = [] # 存储结果 # 遍历所有模型 for model_type in MODEL_METADATA: name = MODEL_METADATA[model_type][0] print(f"评估模型: {name} ({model_type})") # 检查模型是否可用 if not MODEL_METADATA[model_type][1]: results.append({"模型名称": name, "准确率": "N/A"}) continue try: # 训练并评估模型 _, _, accuracy = cls.train_and_evaluate( model_type, X_train, y_train, X_test, y_test ) results.append({"模型名称": name, "准确率": f"{accuracy:.4f}"}) except Exception as e: results.append({"模型名称": name, "准确率": f"错误: {str(e)}"}) # 按准确率排序 results.sort( key=lambda x: float(x["准确率"]) if isinstance(x["准确率"], str) and x["准确率"].replace(&#39;.&#39;, &#39;&#39;, 1).isdigit() else -1, reverse=True ) # 打印结果 print(pd.DataFrame(results)) return results # ======================== # 手写板类 - GUI界面绘图功能 # ======================== class HandwritingBoard: CANVAS_SIZE = 300 # 固定画布尺寸 BRUSH_SIZE = 12 # 画笔大小 def __init__(self, root, model_factory, digits): # 初始化主窗口 self.root = root self.root.title("手写数字识别系统") self.root.geometry("1000x700") # 设置窗口大小 # 模型数据相关 self.model_factory = model_factory # 模型工厂 self.digits = digits # 数字数据集 self.model_cache = {} # 模型缓存(提高切换速度) self.current_model = None # 当前使用的模型 self.scaler = None # 数据标准化器 self.current_model_type = None # 当前模型类型 self.has_drawn = False # 标记是否已绘制数字 self.custom_data = [] # 存储自定义训练数据 # 绘图相关状态 self.drawing = False # 是否正在绘制 self.last_x = self.last_y = 0 # 上次绘制位置 # 创建自定义数据目录 self.data_dir = "custom_digits_data" os.makedirs(self.data_dir, exist_ok=True) # 初始化画布(PIL图像) self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255) # 创建白色背景图像 self.draw_obj = ImageDraw.Draw(self.image) # 创建绘图对象 # 创建界面组件 self.create_widgets() # 初始化默认模型 self.init_default_model() def create_widgets(self): """创建界面组件""" # 创建主框架 main_frame = tk.Frame(self.root) main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) # 1. 模型选择区域 model_frame = tk.LabelFrame(main_frame, text="模型选择", font=("Arial", 10, "bold")) model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", padx=5, pady=5) model_frame.grid_columnconfigure(1, weight=1) # 让模型标签可以扩展 # 模型选择标签 tk.Label(model_frame, text="选择模型:", font=("Arial", 10)).grid(row=0, column=0, padx=5, pady=5, sticky="w") # 获取可用模型列表 self.available_models = [] for model_type, (name, _, _, _) in MODEL_METADATA.items(): if MODEL_METADATA[model_type][1]: self.available_models.append((model_type, name)) # 模型选择下拉框 self.model_var = tk.StringVar() self.model_combobox = ttk.Combobox( model_frame, textvariable=self.model_var, values=[name for _, name in self.available_models], state="readonly", width=25, font=("Arial", 10) ) self.model_combobox.current(0) # 设置默认选项 self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select) # 绑定选择事件 self.model_combobox.grid(row=0, column=1, padx=5, pady=5, sticky="ew") # 模型信息标签(显示准确率) self.model_label = tk.Label( model_frame, text="", font=("Arial", 10), relief=tk.SUNKEN, padx=5, pady=2 ) self.model_label.grid(row=0, column=2, padx=5, pady=5, sticky="ew") # 2. 左侧绘图区域右侧结果区域 # 左侧绘图区域 left_frame = tk.LabelFrame(main_frame, text="绘制区域", font=("Arial", 10, "bold")) left_frame.grid(row=1, column=0, padx=5, pady=5, sticky="nsew") # 绘图画布 self.canvas = tk.Canvas(left_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE) self.canvas.pack(padx=10, pady=10) # 绑定绘图事件 self.canvas.bind("<Button-1>", self.start_draw) # 鼠标按下 self.canvas.bind("<B1-Motion>", self.draw) # 鼠标拖动 self.canvas.bind("<ButtonRelease-1>", self.stop_draw) # 鼠标释放 # 添加绘制提示 self.canvas.create_text( self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2, text="绘制数字", fill="gray", font=("Arial", 16) ) # 绘图控制按钮 btn_frame = tk.Frame(left_frame) btn_frame.pack(fill=tk.X, pady=(0, 10)) # 功能按钮 tk.Button(btn_frame, text="识别", command=self.recognize, width=8).pack(side=tk.LEFT, padx=5) tk.Button(btn_frame, text="清除", command=self.clear_canvas, width=8).pack(side=tk.LEFT, padx=5) tk.Button(btn_frame, text="样本", command=self.show_samples, width=8).pack(side=tk.LEFT, padx=5) # 右侧结果区域 right_frame = tk.Frame(main_frame) right_frame.grid(row=1, column=1, padx=5, pady=5, sticky="nsew") # 2.1 识别结果区域 result_frame = tk.LabelFrame(right_frame, text="识别结果", font=("Arial", 10, "bold")) result_frame.pack(fill=tk.X, padx=5, pady=5) # 结果显示标签 self.result_label = tk.Label( result_frame, text="请绘制数字", font=("Arial", 24), pady=10 ) self.result_label.pack() # 置信度显示标签 self.prob_label = tk.Label( result_frame, text="", font=("Arial", 12) ) self.prob_label.pack() # 2.2 置信度可视化区域 confidence_frame = tk.LabelFrame(right_frame, text="识别置信度", font=("Arial", 10, "bold")) confidence_frame.pack(fill=tk.X, padx=5, pady=5) # 置信度画布(条形图) self.confidence_canvas = tk.Canvas( confidence_frame, bg="white", height=50 ) self.confidence_canvas.pack(fill=tk.X, padx=10, pady=10) self.confidence_canvas.create_text( 150, 25, text="识别后显示置信度", fill="gray", font=("Arial", 10) ) # 2.3 候选数字区域 candidates_frame = tk.LabelFrame(right_frame, text="可能的数字", font=("Arial", 10, "bold")) candidates_frame.pack(fill=tk.X, padx=5, pady=5) # 候选数字表格 columns = ("数字", "概率") self.candidates_tree = ttk.Treeview( candidates_frame, columns=columns, show="headings", height=4 ) # 配置表格列 for col in columns: self.candidates_tree.heading(col, text=col) self.candidates_tree.column(col, width=80, anchor=tk.CENTER) # 添加滚动条 scrollbar = ttk.Scrollbar( candidates_frame, orient=tk.VERTICAL, command=self.candidates_tree.yview ) self.candidates_tree.configure(yscroll=scrollbar.set) self.candidates_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5) scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5) # 3. 模型性能对比训练集管理区域 # 3.1 模型性能对比区域 performance_frame = tk.LabelFrame(main_frame, text="模型性能对比", font=("Arial", 10, "bold")) performance_frame.grid(row=2, column=0, padx=5, pady=5, sticky="nsew") # 性能表格 columns = ("模型名称", "准确率") self.performance_tree = ttk.Treeview( performance_frame, columns=columns, show="headings", height=8 ) # 配置表格列 for col in columns: self.performance_tree.heading(col, text=col) self.performance_tree.column(col, width=120, anchor=tk.CENTER) # 添加滚动条 scrollbar = ttk.Scrollbar( performance_frame, orient=tk.VERTICAL, command=self.performance_tree.yview ) self.performance_tree.configure(yscroll=scrollbar.set) self.performance_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5) scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5) # 3.2 训练集管理区域 train_frame = tk.LabelFrame(main_frame, text="训练集管理", font=("Arial", 10, "bold")) train_frame.grid(row=2, column=1, padx=5, pady=5, sticky="nsew") # 训练集管理按钮 tk.Button( train_frame, text="保存为训练样本", command=self.save_as_training_sample, width=18, height=2 ).grid(row=0, column=0, padx=5, pady=5, sticky="ew") tk.Button( train_frame, text="保存全部训练集", command=self.save_all_training_data, width=18, height=2 ).grid(row=0, column=1, padx=5, pady=5, sticky="ew") tk.Button( train_frame, text="加载训练集", command=self.load_training_data, width=18, height=2 ).grid(row=1, column=0, padx=5, pady=5, sticky="ew") tk.Button( train_frame, text="性能图表", command=self.show_performance_chart, width=18, height=2 ).grid(row=1, column=1, padx=5, pady=5, sticky="ew") # 4. 状态栏 self.status_var = tk.StringVar(value="就绪") status_bar = tk.Label( self.root, textvariable=self.status_var, bd=1, relief=tk.SUNKEN, anchor=tk.W, font=("Arial", 10) ) status_bar.pack(side=tk.BOTTOM, fill=tk.X) # 配置布局权重 main_frame.grid_columnconfigure(0, weight=1) main_frame.grid_columnconfigure(1, weight=1) main_frame.grid_rowconfigure(1, weight=1) main_frame.grid_rowconfigure(2, weight=1) # ======================== # 绘图功能 # ======================== def start_draw(self, event): """开始绘制""" self.drawing = True self.last_x, self.last_y = event.x, event.y def draw(self, event): """绘制""" if not self.drawing: return x, y = event.x, event.y # 在画布上绘制 self.canvas.create_line( self.last_x, self.last_y, x, y, fill="black", width=self.BRUSH_SIZE, capstyle=tk.ROUND, smooth=True ) # 在图像上绘制(用于后续处理) self.draw_obj.line( [self.last_x, self.last_y, x, y], fill=0, # 黑色 width=self.BRUSH_SIZE ) # 更新位置 self.last_x, self.last_y = x, y def stop_draw(self, event): """停止绘制""" self.drawing = False self.has_drawn = True self.status_var.set("已绘制数字,点击&#39;识别&#39;进行识别") def clear_canvas(self): """清除画布""" # 清除画布内容 self.canvas.delete("all") # 重置图像 self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255) # 白色背景 self.draw_obj = ImageDraw.Draw(self.image) # 添加绘制提示 self.canvas.create_text( self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2, text="绘制数字", fill="gray", font=("Arial", 16) ) # 重置结果显示 self.result_label.config(text="请绘制数字") self.prob_label.config(text="") self.clear_confidence_display() self.has_drawn = False self.status_var.set("画布已清除") def clear_confidence_display(self): """清除置信度显示""" self.confidence_canvas.delete("all") self.confidence_canvas.create_text( 150, 25, text="识别后显示置信度", fill="gray", font=("Arial", 10) ) # 清空候选数字表格 for item in self.candidates_tree.get_children(): self.candidates_tree.delete(item) # ======================== # 图像处理功能 # ======================== def preprocess_image(self): """预处理手写数字图像""" # 将PIL图像转换为NumPy数组 img_array = np.array(self.image) # 1. 高斯模糊降噪 img_array = cv2.GaussianBlur(img_array, (5, 5), 0) # 2. 二值化(转换为黑白图像) _, img_array = cv2.threshold(img_array, 127, 255, cv2.THRESH_BINARY_INV) # 3. 轮廓检测(查找数字轮廓) contours, _ = cv2.findContours(img_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if not contours: self.status_var.set("未检测到有效数字,请重新绘制") return None # 4. 找到最大轮廓(即数字部分) c = max(contours, key=cv2.contourArea) x, y, w, h = cv2.boundingRect(c) # 5. 提取数字区域 digit = img_array[y:y+h, x:x+w] # 6. 填充为正方形(保持长宽比) size = max(w, h) padded = np.ones((size, size), dtype=np.uint8) * 255 # 白色背景 offset_x = (size - w) // 2 offset_y = (size - h) // 2 padded[offset_y:offset_y+h, offset_x:offset_x+w] = digit # 7. 缩放为8x8(匹配MNIST数据集格式) resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA) # 8. 归一化(将像素值从0-255映射到0-16) normalized = 16 - (resized / 255 * 16).astype(np.uint8) # 9. 展平为一维数组(64个特征) return normalized.flatten() # ======================== # 识别功能 # ======================== def recognize(self): """识别手写数字""" # 检查是否已绘制数字 if not self.has_drawn: self.status_var.set("请先绘制数字再识别") return # 检查模型是否已加载 if self.current_model is None: self.status_var.set("模型未加载,请选择模型") return # 预处理图像 img_array = self.preprocess_image() if img_array is None: return # 重塑为模型输入格式(1个样本,64个特征) img_input = img_array.reshape(1, -1) try: # 数据标准化 if self.scaler: img_input = self.scaler.transform(img_input) # LightGBM特殊处理(需要DataFrame格式) if self.current_model_type == &#39;lgb&#39; and hasattr(self.current_model, &#39;feature_name_&#39;): img_input = pd.DataFrame(img_input, columns=self.current_model.feature_name_) # 预测数字 pred = self.current_model.predict(img_input)[0] self.result_label.config(text=f"识别结果: {pred}") # 概率预测(如果模型支持) if hasattr(self.current_model, &#39;predict_proba&#39;): probs = self.current_model.predict_proba(img_input)[0] confidence = probs[pred] # 预测结果的置信度 # 更新UI self.prob_label.config(text=f"置信度: {confidence:.2%}") self.update_confidence_display(confidence) # 更新置信度可视化 # 显示候选数字(概率最高的3个) top3 = sorted(enumerate(probs), key=lambda x: -x[1])[:3] self.update_candidates_display(top3) else: self.prob_label.config(text="该模型不支持概率输出") self.clear_confidence_display() self.status_var.set(f"识别完成: 数字 {pred}") except Exception as e: self.status_var.set(f"识别错误: {str(e)}") self.clear_confidence_display() # ======================== # UI更新功能 # ======================== def update_confidence_display(self, confidence): """更新置信度可视化""" self.confidence_canvas.delete("all") # 获取画布宽度 canvas_width = self.confidence_canvas.winfo_width() or 300 # 绘制背景 self.confidence_canvas.create_rectangle( 10, 10, canvas_width - 10, 40, fill="#f0f0f0", outline="#cccccc" ) # 绘制置信度条(根据置信度值) bar_width = int((canvas_width - 20) * confidence) color = self.get_confidence_color(confidence) # 根据置信度选择颜色 self.confidence_canvas.create_rectangle( 10, 10, 10 + bar_width, 40, fill=color, outline="" ) # 绘制文本(显示百分比) self.confidence_canvas.create_text( canvas_width / 2, 25, text=f"{confidence:.1%}", font=("Arial", 10, "bold") ) # 绘制刻度 for i in range(0, 11): x_pos = 10 + i * (canvas_width - 20) / 10 self.confidence_canvas.create_line(x_pos, 40, x_pos, 45, width=1) if i % 2 == 0: self.confidence_canvas.create_text(x_pos, 55, text=f"{i*10}%", font=("Arial", 8)) def get_confidence_color(self, confidence): """根据置信度获取颜色""" # 高置信度:绿色 if confidence >= 0.9: return "#4CAF50" # 中等置信度:黄色 elif confidence >= 0.7: return "#FFC107" # 低置信度:红色 else: return "#F44336" def update_candidates_display(self, candidates): """更新候选数字显示""" # 清空现有项 for item in self.candidates_tree.get_children(): self.candidates_tree.delete(item) # 添加新项(候选数字及其概率) for digit, prob in candidates: self.candidates_tree.insert( "", tk.END, values=(digit, f"{prob:.2%}") ) # ======================== # 样本显示功能 # ======================== def show_samples(self): """显示样本图像""" plt.figure(figsize=(10, 4)) # 显示0-9每个数字的一个样本 for i in range(10): plt.subplot(2, 5, i+1) sample_idx = np.where(self.digits.target == i)[0][0] plt.imshow(self.digits.images[sample_idx], cmap="gray") plt.title(f"数字 {i}", fontsize=9) plt.axis("off") plt.tight_layout() plt.show() # ======================== # 模型管理功能 # ======================== def on_model_select(self, event): """模型选择事件处理""" # 获取选中的模型名称 selected_name = self.model_var.get() # 查找对应的模型类型 model_type = next( (k for k, v in self.available_models if v == selected_name), None ) if model_type: # 切换模型 self.change_model(model_type) def change_model(self, model_type): """切换模型""" model_name = MODEL_METADATA[model_type][0] # 尝试从缓存加载模型 if model_type in self.model_cache: self.current_model, self.scaler, accuracy, self.current_model_type = self.model_cache[model_type] self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})") self.status_var.set(f"已加载模型: {model_name}") return # 加载新模型 self.status_var.set(f"正在加载模型: {model_name}...") self.root.update() # 更新UI显示状态 try: # 获取数据集 X_train, X_test, y_train, y_test = self.model_factory.get_split_data(self.digits) # 训练并评估模型 self.current_model, self.scaler, accuracy = self.model_factory.train_and_evaluate( model_type, X_train, y_train, X_test, y_test ) # 缓存模型 self.current_model_type = model_type self.model_cache[model_type] = (self.current_model, self.scaler, accuracy, self.current_model_type) # 更新UI self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})") self.status_var.set(f"模型加载完成: {model_name}, 准确率: {accuracy:.4f}") self.clear_canvas() # 更新性能表格 self.load_performance_data() except Exception as e: self.status_var.set(f"模型加载失败: {str(e)}") self.model_label.config(text="模型加载失败") def init_default_model(self): """初始化默认模型""" # 设置默认模型并加载 self.model_var.set(self.available_models[0][1]) self.change_model(self.available_models[0][0]) def load_performance_data(self): """加载性能数据""" # 评估所有模型 results = self.model_factory.evaluate_all_models(self.digits) # 清空表格 for item in self.performance_tree.get_children(): self.performance_tree.delete(item) # 添加数据到表格 for i, result in enumerate(results): tag = "highlight" if i == 0 else "" # 高亮显示性能最好的模型 self.performance_tree.insert( "", tk.END, values=(result["模型名称"], result["准确率"]), tags=(tag,) ) # 配置高亮样式 self.performance_tree.tag_configure("highlight", background="#e6f7ff") # ======================== # 性能可视化功能 # ======================== def show_performance_chart(self): """显示性能图表""" # 获取性能数据 results = self.model_factory.evaluate_all_models(self.digits) # 提取有效结果(过滤掉错误数据) valid_results = [] for result in results: try: accuracy = float(result["准确率"]) valid_results.append((result["模型名称"], accuracy)) except ValueError: continue if not valid_results: messagebox.showinfo("提示", "没有可用的性能数据") return # 按准确率排序 valid_results.sort(key=lambda x: x[1], reverse=True) models, accuracies = zip(*valid_results) # 创建水平条形图 plt.figure(figsize=(10, 5)) bars = plt.barh(models, accuracies, color=&#39;#2196F3&#39;) plt.xlabel(&#39;准确率&#39;, fontsize=10) plt.ylabel(&#39;模型&#39;, fontsize=10) plt.title(&#39;模型性能对比&#39;, fontsize=12) plt.xlim(0, 1.05) # 设置X轴范围 # 添加数值标签 for bar in bars: width = bar.get_width() plt.text( width + 0.01, bar.get_y() + bar.get_height()/2, f&#39;{width:.4f}&#39;, ha=&#39;left&#39;, va=&#39;center&#39;, fontsize=8 ) plt.tight_layout() plt.show() # ======================== # 训练集管理功能 # ======================== def save_as_training_sample(self): """保存为训练样本""" # 检查是否已绘制数字 if not self.has_drawn: self.status_var.set("请先绘制数字再保存") return # 预处理图像 img_array = self.preprocess_image() if img_array is None: return # 弹出标签输入窗口 label_window = tk.Toplevel(self.root) label_window.title("输入标签") label_window.geometry("300x150") label_window.transient(self.root) label_window.grab_set() # 模态窗口 # 标签输入提示 tk.Label( label_window, text="请输入数字标签 (0-9):", font=("Arial", 10) ).pack(pady=10) # 输入框 entry = tk.Entry(label_window, font=("Arial", 12), width=5) entry.pack(pady=5) entry.focus_set() def save_with_label(): """保存带标签的样本""" try: # 验证标签 label = int(entry.get()) if label < 0 or label > 9: raise ValueError("标签必须是0-9的数字") # 添加到自定义数据集 self.custom_data.append((img_array.tolist(), label)) self.status_var.set(f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)") label_window.destroy() except ValueError as e: self.status_var.set(f"保存错误: {str(e)}") # 保存按钮 tk.Button( label_window, text="保存", command=save_with_label, width=10 ).pack(pady=5) def save_all_training_data(self): """保存全部训练数据""" # 检查是否有数据可保存 if not self.custom_data: self.status_var.set("没有训练数据可保存") return # 弹出文件保存对话框 file_path = filedialog.asksaveasfilename( defaultextension=".csv", filetypes=[("CSV文件", "*.csv")], initialfile="custom_digits.csv", title="保存训练集" ) if not file_path: return try: # 写入CSV文件 with open(file_path, &#39;w&#39;, newline=&#39;&#39;, encoding=&#39;utf-8&#39;) as f: writer = csv.writer(f) # 写入表头(64个像素+标签) writer.writerow([f&#39;pixel{i}&#39; for i in range(64)] + [&#39;label&#39;]) # 写入数据 for img_data, label in self.custom_data: writer.writerow(img_data + [label]) self.status_var.set(f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}") except Exception as e: self.status_var.set(f"保存失败: {str(e)}") def load_training_data(self): """加载训练数据""" # 弹出文件选择对话框 file_path = filedialog.askopenfilename( filetypes=[("CSV文件", "*.csv")], title="加载训练集" ) if not file_path: return try: self.custom_data = [] # 读取CSV文件 with open(file_path, &#39;r&#39;, newline=&#39;&#39;, encoding=&#39;utf-8&#39;) as f: reader = csv.reader(f) next(reader) # 跳过标题行 # 解析每一行数据 for row in reader: if len(row) != 65: # 64像素+1标签 continue # 提取像素数据标签 img_data = [float(pixel) for pixel in row[:64]] label = int(row[64]) self.custom_data.append((img_data, label)) self.status_var.set(f"已加载 {len(self.custom_data)} 个样本") except Exception as e: self.status_var.set(f"加载失败: {str(e)}") # ======================== # 主程序入口 # ======================== def run(self): """运行应用""" self.root.mainloop() # ======================== # 程序入口 # ======================== if __name__ == "__main__": digits = load_digits() # 加载数字数据集 root = tk.Tk() # 创建主窗口 app = HandwritingBoard(root, ModelFactory, digits) # 创建应用实例 app.run() # 运行应用 请你根据上面的内容生成符合要求的课程设计:
最新发布
06-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值