<input type="file"/>控件是否为空验证,文件类型验证,file控件清空,禁止手写。

本文介绍了一个HTML文件上传表单的设计与实现,包括如何禁止手动输入、验证文件是否为空、检查文件类型是否符合要求(仅限.xls文件),以及如何在上传不符合条件的文件时清空file控件。

file控件禁止手写,是否为空验证,文件类型验证,file控件清空。

<html>
<body bgcolor="white">
	<TABLE cellSpacing=0 cellPadding=0 width="100%" border=0>
		<TBODY>
			<TR>
				<TD>
					<center>
						<h4>
							考勤记录导入
						</h4>
					</center>
					<hr/>
					<br/>
					<center>
					<!-- Multipart upload form; onsubmit runs check(), which cancels submission when no file is chosen. -->
					<form action="<%=request.getContextPath() %>/admin/HtmlFile.do" method="post" enctype="multipart/form-data" name="HtmlFileForm" onsubmit="return check()">
						<!-- The span wraps the file control so resetFile() can rebuild it (empty) from its load-time markup.
						     accept only pre-filters the file dialog to .xls; it is a hint, so checkFileType() still validates on change.
						     NOTE(review): unselectable="on" is a legacy IE attribute; confirm it is still needed. -->
						<span id='uploadSpan'>
						<input type="file" name="file" id="file" accept=".xls,application/vnd.ms-excel" unselectable="on" onchange='checkFileType(this.value);'/>&nbsp;
						</span>
					<input type="submit" value="提交"/>
					</form>
					</center>
				</TD>
			</TR>
		</TBODY>
	</TABLE>
<script type="text/javascript">
/**
 * Validate the chosen file's extension; only ".xls" (case-insensitive) is accepted.
 * Invoked from the file input's onchange handler with the input's value.
 * On a mismatch the user is alerted and the control is cleared via resetFile().
 * @param {string} str  Value of the file input (the selected file's path/name).
 */
function checkFileType(str){
	// Empty value: nothing selected (e.g. the dialog was cancelled), so there is
	// nothing to validate; alerting here would complain about an empty file type.
	if (!str) {
		return;
	}
	// Look for the extension in the file name only, so a dot in a directory
	// name is never mistaken for the extension.
	var fileName = str.substring(Math.max(str.lastIndexOf('\\'), str.lastIndexOf('/')) + 1);
	var pos = fileName.lastIndexOf(".");
	// No dot means no extension (substring(-1) would otherwise return the whole name).
	var lastname = pos >= 0 ? fileName.substring(pos) : '';
	if (lastname.toLowerCase() !== '.xls') {
		alert('只能上传xls文件,您上传的文件类型为'+lastname+',请重新上传');
		resetFile();
	}
}
// Markup of the span around the file control, captured once when this script
// runs; it serves as the template for an empty file control.
var html = document.getElementById('uploadSpan').innerHTML;

/**
 * Clear the file control by re-rendering its wrapper span from the markup
 * captured at load time, which replaces the input with a fresh, empty one.
 */
function resetFile() {
	var uploadSpan = document.getElementById('uploadSpan');
	uploadSpan.innerHTML = html;
}

/**
 * onsubmit handler of the upload form: block submission when no file is chosen.
 * @returns {boolean} false (cancel the submit) when the file input is empty, true otherwise.
 */
function check()
{
	// Resolve the form explicitly rather than relying on name="HtmlFileForm"
	// being exposed as an implicit global (any same-named global would shadow it).
	// elements['file'] is looked up on each call, so it also finds the control
	// after resetFile() has rebuilt it.
	var fileInput = document.forms['HtmlFileForm'].elements['file'];
	if (!fileInput.value) {
		alert("请选择文件路径");
		return false;
	}
	return true;
}
	
</script>
</body>

</html>
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
    <!-- TensorFlow.js, loaded in <head> so `tf` exists before the inline script below runs. -->
    <script src="tf.min.js"></script>
    <style>
        #canvas {
            border: 2px solid #333;
            background: black;
            cursor: crosshair;
            width: 280px;
            height: 280px;
        }
        #result {
            font-size: 24px;
            margin: 15px 0;
            min-height: 30px;
        }
        button {
            padding: 8px 15px;
            margin: 5px;
            background: #4285f4;
            color: white;
            border: none;
            border-radius: 4px;
        }
        body {
            display: flex;
            justify-content: center; /* center horizontally */
            align-items: center; /* center vertically */
            min-height: 100vh; /* make the container fill the viewport height */
            margin: 0; /* remove the default body margin */
            flex-direction: column; /* stack children vertically */
        }
    </style>
</head>
<body>
<h1>MNIST手写数字识别</h1>
<canvas id="canvas" width="280" height="280"></canvas>
<div>
    <button id="clear">清除</button>
    <button id="predict">识别</button>
</div>
<div id="result">请绘制0-9的数字</div>
<script>
    // Path of the model file passed to tf.loadGraphModel.
    const MODEL_URL = 'model/model.json';
    // Loaded model, canvas element, its 2D context, and whether a stroke is in progress.
    let model, canvas, ctx, isDrawing = false;
    // Previous pointer position of the current stroke.
    let lastX = 0, lastY = 0;

    // Load the graph model; on failure log it and show a message in the result area.
    async function loadModel() {
        try {
            model = await tf.loadGraphModel(MODEL_URL);
            console.log('模型加载成功');
        } catch (error) {
            console.error('模型加载失败:', error);
            document.getElementById('result').innerHTML = '模型加载失败';
        }
    }

    // Fill the canvas black, set up a 15px round white brush, and attach mouse drawing handlers.
    function initCanvas() {
        canvas = document.getElementById('canvas');
        ctx = canvas.getContext('2d');
        ctx.fillStyle = 'black';
        ctx.fillRect(0, 0, canvas.width, canvas.height);
        ctx.lineWidth = 15;
        ctx.lineCap = 'round';
        ctx.strokeStyle = 'white';
        canvas.addEventListener('mousedown', startDrawing);
        canvas.addEventListener('mousemove', draw);
        canvas.addEventListener('mouseup', stopDrawing);
        canvas.addEventListener('mouseout', stopDrawing);
    }

    // Begin a stroke at the pointer position.
    function startDrawing(e) {
        isDrawing = true;
        [lastX, lastY] = [e.offsetX, e.offsetY];
    }

    // Extend the current stroke with a segment from the previous to the current position.
    function draw(e) {
        if (!isDrawing) return;
        ctx.beginPath();
        ctx.moveTo(lastX, lastY);
        ctx.lineTo(e.offsetX, e.offsetY);
        ctx.stroke();
        [lastX, lastY] = [e.offsetX, e.offsetY];
    }

    // End the current stroke (mouse released or left the canvas).
    function stopDrawing() {
        isDrawing = false;
    }

    // Downscale the canvas to 28x28 and convert it to a normalized input tensor.
    function preprocessCanvas() {
        // tf.tidy disposes every intermediate tensor except the returned one.
        return tf.tidy(() => {
            const tempCanvas = document.createElement('canvas');
            tempCanvas.width = 28;
            tempCanvas.height = 28;
            const tempCtx = tempCanvas.getContext('2d');
            tempCtx.drawImage(canvas, 0, 0, 28, 28);
            // One channel, scaled to [0, 1], reshaped to [1, 28, 28].
            // NOTE(review): assumes the model's input signature is [batch, 28, 28] — confirm against model.json.
            return tf.browser.fromPixels(tempCanvas, 1)
                .toFloat()
                .div(255)
                .reshape([1, 28, 28]);
        });
    }

    // Run the model on the current drawing and show the arg-max class.
    async function predict() {
        if (!model) {
            document.getElementById('result').innerHTML = '模型未加载';
            return;
        }
        try {
            // Everything runs inside tf.tidy, so the input, the prediction and the
            // argMax tensor are all released, even if model.predict throws.
            const result = tf.tidy(() => {
                const input = preprocessCanvas();
                const pred = model.predict(input);
                return pred.argMax(1).dataSync()[0];
            });
            document.getElementById('result').innerHTML = `识别结果: ${result}`;
        } catch (error) {
            console.error('预测错误:', error);
            document.getElementById('result').innerHTML = '预测失败';
        }
    }

    // Repaint the canvas black and reset the result text.
    function clearCanvas() {
        ctx.fillStyle = 'black';
        ctx.fillRect(0, 0, canvas.width, canvas.height);
        document.getElementById('result').innerHTML = '请绘制0-9的数字';
    }

    window.onload = async () => {
        initCanvas();
        await loadModel();
        document.getElementById('predict').addEventListener('click', predict);
        document.getElementById('clear').addEventListener('click', clearCanvas);
    };
</script>
</body>
</html>
模仿这个写上传图片识别保定美食
06-27
#include <QApplication>
#include <QMainWindow>
#include <QWidget>
#include <QPushButton>
#include <QLabel>
#include <QSlider>
#include <QSpinBox>
#include <QComboBox>
#include <QGroupBox>
#include <QHBoxLayout>
#include <QGridLayout>
#include <QPainter>
#include <QMouseEvent>
#include <QFileDialog>
#include <QMessageBox>
#include <QTimer>
#include <QPixmap>
#include <QPen>
#include <QColor>
#include <QDebug>
#include <QProgressBar>
#include <QProgressDialog>
#include <QTime>
#include <QElapsedTimer>
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <fstream>
#include <algorithm>
#include <string>
#include <iomanip>
#include <limits>
#include <utility>
using namespace std;

// Reverse the byte order of a 32-bit int. MNIST headers are big-endian; this
// converts them for a little-endian host.
int reverseInt(int i) {
    unsigned char c1, c2, c3, c4;
    c1 = i & 255;
    c2 = (i >> 8) & 255;
    c3 = (i >> 16) & 255;
    c4 = (i >> 24) & 255;
    return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
}

// Read an MNIST image file (magic, count, rows, cols header, then one byte per pixel).
// numberOfImages and imageSize (rows * cols) are output parameters taken from the header.
// Pixels are normalized to [0, 1]. Returns an empty vector if the file cannot be opened.
vector<vector<double>> readMNISTImages(const string& filename, int& numberOfImages, int& imageSize) {
    ifstream file(filename, ios::binary);
    vector<vector<double>> images;
    if (file.is_open()) {
        int magicNumber = 0;
        int nRows = 0, nCols = 0;
        file.read((char*)&magicNumber, sizeof(magicNumber));
        magicNumber = reverseInt(magicNumber);
        file.read((char*)&numberOfImages, sizeof(numberOfImages));
        numberOfImages = reverseInt(numberOfImages);
        file.read((char*)&nRows, sizeof(nRows));
        nRows = reverseInt(nRows);
        file.read((char*)&nCols, sizeof(nCols));
        nCols = reverseInt(nCols);
        imageSize = nRows * nCols;
        cout << "正在加载图像文件: " << filename << endl;
        cout << "图像数量: " << numberOfImages << ", 图像尺寸: " << nRows << "x" << nCols << endl;
        for (int i = 0; i < numberOfImages; ++i) {
            if (i % 10000 == 0) {
                cout << "已加载 " << i << " 张图像..." << endl;
            }
            vector<double> image(imageSize);
            for (int j = 0; j < imageSize; ++j) {
                unsigned char pixel = 0;
                file.read((char*)&pixel, sizeof(pixel));
                image[j] = static_cast<double>(pixel) / 255.0; // normalize to [0, 1]
            }
            images.push_back(std::move(image));
        }
        file.close();
        cout << "图像加载完成!" << endl;
    } else {
        cerr << "无法打开文件: " << filename << endl;
    }
    return images;
}

// Read an MNIST label file (magic, count header, then one byte per label).
// numberOfLabels is an output parameter taken from the header.
// Returns an empty vector if the file cannot be opened.
vector<int> readMNISTLabels(const string& filename, int& numberOfLabels) {
    ifstream file(filename, ios::binary);
    vector<int> labels;
    if (file.is_open()) {
        int magicNumber = 0;
        file.read((char*)&magicNumber, sizeof(magicNumber));
        magicNumber = reverseInt(magicNumber);
        file.read((char*)&numberOfLabels, sizeof(numberOfLabels));
        numberOfLabels = reverseInt(numberOfLabels);
        cout << "正在加载标签文件: " << filename << endl;
        cout << "标签数量: " << numberOfLabels << endl;
        for (int i = 0; i < numberOfLabels; ++i) {
            if (i % 10000 == 0) {
                cout << "已加载 " << i << " 个标签..." << endl;
            }
            unsigned char label = 0;
            file.read((char*)&label, sizeof(label));
            labels.push_back(static_cast<int>(label));
        }
        file.close();
        cout << "标签加载完成!" << endl;
    } else {
        cerr << "无法打开文件: " << filename << endl;
    }
    return labels;
}

// Image-processing helpers operating on row-major grayscale buffers in [0, 1].
class ImageProcessor {
public:
    // Median filter for denoising. Border pixels use only the in-bounds part of the window.
    static vector<double> medianFilter(const vector<double>& input, int width, int height, int kernelSize = 3) {
        if (kernelSize % 2 == 0) kernelSize++; // the kernel size must be odd
        vector<double> output(input.size());
        int padding = kernelSize / 2;
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                vector<double> window;
                for (int ky = -padding; ky <= padding; ky++) {
                    for (int kx = -padding; kx <= padding; kx++) {
                        int ny = y + ky;
                        int nx = x + kx;
                        if (ny >= 0 && ny < height && nx >= 0 && nx < width) {
                            window.push_back(input[ny * width + nx]);
                        }
                    }
                }
                // Sort and take the middle element.
                sort(window.begin(), window.end());
                output[y * width + x] = window[window.size() / 2];
            }
        }
        return output;
    }

    // 5x5 Gaussian blur with a normalized kernel. Out-of-bounds taps are skipped.
    static vector<double> gaussianBlur(const vector<double>& input, int width, int height, double sigma = 1.0) {
        int kernelSize = 5;
        int padding = kernelSize / 2;
        // Build the Gaussian kernel.
        vector<vector<double>> kernel(kernelSize, vector<double>(kernelSize));
        double sum = 0.0;
        for (int y = -padding; y <= padding; y++) {
            for (int x = -padding; x <= padding; x++) {
                double value = exp(-(x * x + y * y) / (2 * sigma * sigma));
                kernel[y + padding][x + padding] = value;
                sum += value;
            }
        }
        // Normalize the kernel so its weights sum to 1.
        for (int y = 0; y < kernelSize; y++) {
            for (int x = 0; x < kernelSize; x++) {
                kernel[y][x] /= sum;
            }
        }
        // Convolve.
        vector<double> output(input.size(), 0.0);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                double value = 0.0;
                for (int ky = -padding; ky <= padding; ky++) {
                    for (int kx = -padding; kx <= padding; kx++) {
                        int ny = y + ky;
                        int nx = x + kx;
                        if (ny >= 0 && ny < height && nx >= 0 && nx < width) {
                            value += input[ny * width + nx] * kernel[ky + padding][kx + padding];
                        }
                    }
                }
                output[y * width + x] = value;
            }
        }
        return output;
    }

    // Binarize: values strictly above thresholdValue become 1.0, the rest 0.0.
    static vector<double> threshold(const vector<double>& input, double thresholdValue = 0.5) {
        vector<double> output(input.size());
        for (size_t i = 0; i < input.size(); i++) {
            output[i] = input[i] > thresholdValue ? 1.0 : 0.0;
        }
        return output;
    }

    // Linearly stretch [minVal, maxVal] to [0, 1] and clamp.
    static vector<double> enhanceContrast(const vector<double>& input, double minVal, double maxVal) {
        vector<double> output(input.size());
        for (size_t i = 0; i < input.size(); i++) {
            output[i] = (input[i] - minVal) / (maxVal - minVal);
            output[i] = max(0.0, min(1.0, output[i])); // clamp to [0, 1]
        }
        return output;
    }

    // Full pipeline: median filter, Gaussian blur (sigma 0.8), then binarize at 1.2x the mean.
    static vector<double> preprocessImage(const vector<double>& input, int width, int height) {
        vector<double> processed = input;
        processed = medianFilter(processed, width, height);
        processed = gaussianBlur(processed, width, height, 0.8);
        // Automatic threshold: slightly above the mean intensity.
        double sum = 0.0;
        for (double val : processed) {
            sum += val;
        }
        double mean = sum / processed.size();
        double thresholdValue = mean * 1.2;
        processed = threshold(processed, thresholdValue);
        return processed;
    }

    // Load an image file, convert it to width x height grayscale, invert, and preprocess.
    // Returns an empty vector if the image cannot be loaded.
    static vector<double> loadAndPreprocessImage(const QString& filePath, int width = 28, int height = 28) {
        QImage image;
        if (!image.load(filePath)) {
            qDebug() << "无法加载图像:" << filePath;
            return vector<double>();
        }
        QImage grayImage = image.convertToFormat(QImage::Format_Grayscale8);
        QImage scaledImage = grayImage.scaled(width, height, Qt::IgnoreAspectRatio, Qt::SmoothTransformation);
        vector<double> imageData(width * height);
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                QRgb pixel = scaledImage.pixel(x, y);
                // Luma, inverted: MNIST digits are light strokes on a dark background,
                // so a dark-on-light input image is flipped to match.
                double gray = 1.0 - (qRed(pixel) * 0.299 + qGreen(pixel) * 0.587 + qBlue(pixel) * 0.114) / 255.0;
                imageData[y * width + x] = gray;
            }
        }
        imageData = preprocessImage(imageData, width, height);
        return imageData;
    }
};

// Fully connected network: input -> sigmoid hidden layer -> softmax output,
// trained one sample at a time with SGD on cross-entropy loss.
class NeuralNetwork {
private:
    // Layer sizes.
    int inputSize;
    int hiddenSize;
    int outputSize;
    // Weights and biases (weights indexed [to][from]).
    vector<vector<double>> weightsInputHidden;
    vector<vector<double>> weightsHiddenOutput;
    vector<double> biasHidden;
    vector<double> biasOutput;
    double learningRate;
    // Default-seeded RNG, so the initialization is reproducible.
    default_random_engine generator;
    normal_distribution<double> distribution;

public:
    NeuralNetwork(int input, int hidden, int output, double lr = 0.1)
        : inputSize(input), hiddenSize(hidden), outputSize(output),
          learningRate(lr), distribution(0.0, 1.0) {
        initializeWeights();
    }

    // Initialize all weights and biases from N(0, 1) scaled by 0.1.
    void initializeWeights() {
        weightsInputHidden.resize(hiddenSize, vector<double>(inputSize));
        for (int i = 0; i < hiddenSize; ++i) {
            for (int j = 0; j < inputSize; ++j) {
                weightsInputHidden[i][j] = distribution(generator) * 0.1;
            }
        }
        weightsHiddenOutput.resize(outputSize, vector<double>(hiddenSize));
        for (int i = 0; i < outputSize; ++i) {
            for (int j = 0; j < hiddenSize; ++j) {
                weightsHiddenOutput[i][j] = distribution(generator) * 0.1;
            }
        }
        biasHidden.resize(hiddenSize);
        for (int i = 0; i < hiddenSize; ++i) {
            biasHidden[i] = distribution(generator) * 0.1;
        }
        biasOutput.resize(outputSize);
        for (int i = 0; i < outputSize; ++i) {
            biasOutput[i] = distribution(generator) * 0.1;
        }
    }

    double sigmoid(double x) {
        return 1.0 / (1.0 + exp(-x));
    }

    // Sigmoid derivative expressed in terms of the sigmoid OUTPUT x (not its input).
    double sigmoidDerivative(double x) {
        return x * (1.0 - x);
    }

    // Numerically stable softmax (the maximum is subtracted before exponentiating).
    vector<double> softmax(const vector<double>& x) {
        vector<double> result(x.size());
        double maxVal = *max_element(x.begin(), x.end());
        double sum = 0.0;
        for (size_t i = 0; i < x.size(); ++i) {
            result[i] = exp(x[i] - maxVal);
            sum += result[i];
        }
        for (size_t i = 0; i < x.size(); ++i) {
            result[i] /= sum;
        }
        return result;
    }

    // Cross-entropy loss; 1e-15 avoids log(0).
    double crossEntropyLoss(const vector<double>& output, const vector<double>& target) {
        double loss = 0.0;
        for (size_t i = 0; i < output.size(); ++i) {
            loss -= target[i] * log(output[i] + 1e-15);
        }
        return loss;
    }

    // Forward pass; returns the softmax class probabilities.
    vector<double> forward(const vector<double>& input) {
        vector<double> hidden(hiddenSize);
        for (int i = 0; i < hiddenSize; ++i) {
            double sum = biasHidden[i];
            for (int j = 0; j < inputSize; ++j) {
                sum += input[j] * weightsInputHidden[i][j];
            }
            hidden[i] = sigmoid(sum);
        }
        vector<double> output(outputSize);
        for (int i = 0; i < outputSize; ++i) {
            double sum = biasOutput[i];
            for (int j = 0; j < hiddenSize; ++j) {
                sum += hidden[j] * weightsHiddenOutput[i][j];
            }
            output[i] = sum; // linear output before softmax
        }
        return softmax(output);
    }

    // One SGD step on a single (input, one-hot target) pair.
    void train(const vector<double>& input, const vector<double>& target) {
        vector<double> hidden(hiddenSize);
        for (int i = 0; i < hiddenSize; ++i) {
            double sum = biasHidden[i];
            for (int j = 0; j < inputSize; ++j) {
                sum += input[j] * weightsInputHidden[i][j];
            }
            hidden[i] = sigmoid(sum);
        }
        vector<double> outputLinear(outputSize);
        for (int i = 0; i < outputSize; ++i) {
            double sum = biasOutput[i];
            for (int j = 0; j < hiddenSize; ++j) {
                sum += hidden[j] * weightsHiddenOutput[i][j];
            }
            outputLinear[i] = sum;
        }
        vector<double> output = softmax(outputLinear);
        // With softmax + cross-entropy, the gradient w.r.t. the logits is (output - target).
        vector<double> outputErrors(outputSize);
        for (int i = 0; i < outputSize; ++i) {
            outputErrors[i] = output[i] - target[i];
        }
        // Backpropagate to the hidden layer.
        vector<double> hiddenErrors(hiddenSize);
        for (int i = 0; i < hiddenSize; ++i) {
            double error = 0.0;
            for (int j = 0; j < outputSize; ++j) {
                error += outputErrors[j] * weightsHiddenOutput[j][i];
            }
            hiddenErrors[i] = error * sigmoidDerivative(hidden[i]);
        }
        // Update hidden->output weights and biases.
        for (int i = 0; i < outputSize; ++i) {
            for (int j = 0; j < hiddenSize; ++j) {
                weightsHiddenOutput[i][j] -= learningRate * outputErrors[i] * hidden[j];
            }
            biasOutput[i] -= learningRate * outputErrors[i];
        }
        // Update input->hidden weights and biases.
        for (int i = 0; i < hiddenSize; ++i) {
            for (int j = 0; j < inputSize; ++j) {
                weightsInputHidden[i][j] -= learningRate * hiddenErrors[i] * input[j];
            }
            biasHidden[i] -= learningRate * hiddenErrors[i];
        }
    }

    // Index of the most probable class.
    int predict(const vector<double>& input) {
        vector<double> output = forward(input);
        return distance(output.begin(), max_element(output.begin(), output.end()));
    }

    // Per-class probabilities for input.
    vector<double> getConfidence(const vector<double>& input) {
        return forward(input);
    }

    // Accuracy over at most the first 1000 samples (capped for speed).
    // Returns 0.0 for empty or mismatched inputs.
    double calculateAccuracy(const vector<vector<double>>& testImages, const vector<int>& testLabels) {
        if (testImages.empty() || testLabels.empty() || testImages.size() != testLabels.size()) {
            return 0.0;
        }
        int correct = 0;
        int total = min(1000, static_cast<int>(testImages.size()));
        for (int i = 0; i < total; ++i) {
            int prediction = predict(testImages[i]);
            if (prediction == testLabels[i]) {
                correct++;
            }
        }
        return static_cast<double>(correct) / total;
    }

    // Save sizes, learning rate, weights and biases as whitespace-separated text.
    // Returns true on success.
    bool saveModel(const string& filename) {
        ofstream file(filename);
        if (!file.is_open()) {
            cerr << "无法保存模型到文件: " << filename << endl;
            return false;
        }
        // Enough significant digits that every double reads back exactly
        // (the stream default of 6 would silently perturb the weights).
        file << setprecision(numeric_limits<double>::max_digits10);
        file << inputSize << " " << hiddenSize << " " << outputSize << " " << learningRate << "\n";
        for (int i = 0; i < hiddenSize; ++i) {
            for (int j = 0; j < inputSize; ++j) {
                file << weightsInputHidden[i][j] << " ";
            }
            file << "\n";
        }
        for (int i = 0; i < outputSize; ++i) {
            for (int j = 0; j < hiddenSize; ++j) {
                file << weightsHiddenOutput[i][j] << " ";
            }
            file << "\n";
        }
        for (int i = 0; i < hiddenSize; ++i) {
            file << biasHidden[i] << " ";
        }
        file << "\n";
        for (int i = 0; i < outputSize; ++i) {
            file << biasOutput[i] << " ";
        }
        file << "\n";
        file.close(); // sets failbit if the final flush fails
        if (!file) {
            cerr << "无法保存模型到文件: " << filename << endl;
            return false;
        }
        cout << "模型已保存到: " << filename << endl;
        return true;
    }

    // Load a model written by saveModel. The file is fully parsed and validated
    // into temporaries first, so on failure the current network is left untouched.
    // Returns true on success.
    bool loadModel(const string& filename) {
        ifstream file(filename);
        if (!file.is_open()) {
            cerr << "无法加载模型文件: " << filename << endl;
            return false;
        }
        int newInputSize = 0, newHiddenSize = 0, newOutputSize = 0;
        double newLearningRate = 0.0;
        if (!(file >> newInputSize >> newHiddenSize >> newOutputSize >> newLearningRate)
            || newInputSize <= 0 || newHiddenSize <= 0 || newOutputSize <= 0) {
            cerr << "模型文件格式错误: " << filename << endl;
            return false;
        }
        // Freshly sized containers: resize() on the existing ones would keep the
        // old row widths when only the inner dimension changes.
        vector<vector<double>> newWeightsInputHidden(newHiddenSize, vector<double>(newInputSize));
        vector<vector<double>> newWeightsHiddenOutput(newOutputSize, vector<double>(newHiddenSize));
        vector<double> newBiasHidden(newHiddenSize);
        vector<double> newBiasOutput(newOutputSize);
        for (auto& row : newWeightsInputHidden) {
            for (double& w : row) file >> w;
        }
        for (auto& row : newWeightsHiddenOutput) {
            for (double& w : row) file >> w;
        }
        for (double& b : newBiasHidden) file >> b;
        for (double& b : newBiasOutput) file >> b;
        if (!file) {
            cerr << "模型文件数据不完整: " << filename << endl;
            return false;
        }
        inputSize = newInputSize;
        hiddenSize = newHiddenSize;
        outputSize = newOutputSize;
        learningRate = newLearningRate;
        weightsInputHidden = std::move(newWeightsInputHidden);
        weightsHiddenOutput = std::move(newWeightsHiddenOutput);
        biasHidden = std::move(newBiasHidden);
        biasOutput = std::move(newBiasOutput);
        cout << "模型已从 " << filename << "加载" << endl;
        return true;
    }
};

// One-hot encode label over numClasses classes. An out-of-range label (corrupt
// data) yields an all-zero vector instead of writing out of bounds.
vector<double> toOneHot(int label, int numClasses) {
    vector<double> oneHot(numClasses, 0.0);
    if (label >= 0 && label < numClasses) {
        oneHot[label] = 1.0;
    }
    return oneHot;
}

// 400x400 white canvas on which the user draws black strokes with the mouse.
class DrawingCanvas : public QWidget {
    Q_OBJECT
public:
    explicit DrawingCanvas(QWidget *parent = nullptr) : QWidget(parent) {
        setFixedSize(400, 400);
        setAutoFillBackground(true);
        QPalette palette;
        palette.setColor(QPalette::Window, Qt::white);
        setPalette(palette);
        clearCanvas();
    }

    void clearCanvas() {
        canvasImage = QImage(400, 400, QImage::Format_RGB32);
        canvasImage.fill(Qt::white);
        update();
    }

    QImage getImage() const {
        return canvasImage;
    }

    // Save the canvas to an image file; returns QImage::save's result.
    bool saveImage(const QString& fileName) {
        return canvasImage.save(fileName);
    }

    // Drawing scaled to 28x28, inverted and preprocessed, as 784 values in [0, 1].
    vector<double> getNormalizedImageData() {
        QImage scaledImage = canvasImage.scaled(28, 28, Qt::IgnoreAspectRatio, Qt::SmoothTransformation);
        vector<double> imageData(28 * 28);
        for (int y = 0; y < 28; ++y) {
            for (int x = 0; x < 28; ++x) {
                QRgb pixel = scaledImage.pixel(x, y);
                // Luma, inverted: the canvas is black-on-white, while MNIST digits
                // are light strokes on a dark background.
                double gray = 1.0 - (qRed(pixel) * 0.299 + qGreen(pixel) * 0.587 + qBlue(pixel) * 0.114) / 255.0;
                imageData[y * 28 + x] = gray;
            }
        }
        imageData = ImageProcessor::preprocessImage(imageData, 28, 28);
        return imageData;
    }

protected:
    void mousePressEvent(QMouseEvent *event) override {
        if (event->button() == Qt::LeftButton) {
            lastPoint = event->pos();
            drawing = true;
        }
    }

    void mouseMoveEvent(QMouseEvent *event) override {
        if ((event->buttons() & Qt::LeftButton) && drawing) {
            drawLineTo(event->pos());
        }
    }

    void mouseReleaseEvent(QMouseEvent *event) override {
        if (event->button() == Qt::LeftButton && drawing) {
            drawLineTo(event->pos());
            drawing = false;
        }
    }

    void paintEvent(QPaintEvent *event) override {
        Q_UNUSED(event);
        QPainter painter(this);
        painter.drawImage(0, 0, canvasImage);
    }

private:
    // Draw a 20px round black segment and repaint only the affected rectangle.
    void drawLineTo(const QPoint &endPoint) {
        QPainter painter(&canvasImage);
        painter.setPen(QPen(Qt::black, 20, Qt::SolidLine, Qt::RoundCap, Qt::RoundJoin));
        painter.drawLine(lastPoint, endPoint);
        int rad = 10;
        update(QRect(lastPoint, endPoint).normalized().adjusted(-rad, -rad, rad, rad));
        lastPoint = endPoint;
    }

    QImage canvasImage;
    QPoint lastPoint;
    bool drawing = false;
};

// Main window: drawing canvas on the left, controls and results on the right.
class MainWindow : public QMainWindow {
    Q_OBJECT
public:
    MainWindow(QWidget *parent = nullptr) : QMainWindow(parent) {
        setupUI();
        setupNeuralNetwork();
    }

    ~MainWindow() {
        delete nn;
    }

private slots:
    void onClearButtonClicked() {
        drawingCanvas->clearCanvas();
        predictionLabel->setText("预测: -");
        confidenceLabel->setText("置信度: -");
        timeLabel->setText("用时: -");
        for (int i = 0; i < 10; ++i) {
            confidenceBars[i]->setValue(0);
            confidenceLabels[i]->setText("0%");
        }
    }

    void onRecognizeButtonClicked() {
        if (!nn) {
            QMessageBox::warning(this, "错误", "神经网络未初始化");
            return;
        }
        QElapsedTimer timer;
        timer.start();
        vector<double> imageData = drawingCanvas->getNormalizedImageData();
        int prediction = nn->predict(imageData);
        qint64 elapsed = timer.elapsed();
        vector<double> output = nn->getConfidence(imageData);
        double confidence = output[prediction] * 100;
        predictionLabel->setText(QString("预测: %1").arg(prediction));
        confidenceLabel->setText(QString("置信度: %1%").arg(confidence, 0, 'f', 2));
        timeLabel->setText(QString("用时: %1 毫秒").arg(elapsed));
        for (int i = 0; i < 10; ++i) {
            int value = static_cast<int>(output[i] * 100);
            confidenceBars[i]->setValue(value);
            confidenceLabels[i]->setText(QString("%1%").arg(value));
        }
    }

    void onTrainButtonClicked() {
        QString imagesPath = QFileDialog::getOpenFileName(this, "选择训练图像文件", "", "MNIST图像文件 (*.idx3-ubyte)");
        QString labelsPath = QFileDialog::getOpenFileName(this, "选择训练标签文件", "", "MNIST标签文件 (*.idx1-ubyte)");
        if (imagesPath.isEmpty() || labelsPath.isEmpty()) {
            QMessageBox::warning(this, "错误", "请选择MNIST数据文件");
            return;
        }
        QMessageBox::StandardButton reply;
        reply = QMessageBox::question(this, "验证集", "是否使用验证集评估训练过程?", QMessageBox::Yes | QMessageBox::No);
        bool useValidation = (reply == QMessageBox::Yes);
        QString validationImagesPath, validationLabelsPath;
        int validationNumberOfImages = 0, validationImageSize = 0;
        vector<vector<double>> validationImages;
        vector<int> validationLabels;
        if (useValidation) {
            validationImagesPath = QFileDialog::getOpenFileName(this, "选择验证图像文件", "", "MNIST图像文件 (*.idx3-ubyte)");
            validationLabelsPath = QFileDialog::getOpenFileName(this, "选择验证标签文件", "", "MNIST标签文件 (*.idx1-ubyte)");
            if (validationImagesPath.isEmpty() || validationLabelsPath.isEmpty()) {
                QMessageBox::warning(this, "错误", "请选择验证集数据文件");
                return;
            }
            validationImages = readMNISTImages(validationImagesPath.toStdString(), validationNumberOfImages, validationImageSize);
            int validationNumberOfLabels = 0;
            validationLabels = readMNISTLabels(validationLabelsPath.toStdString(), validationNumberOfLabels);
            // Mismatched sizes would make calculateAccuracy silently report 0%.
            if (validationImages.empty() || validationLabels.empty()
                || validationImages.size() != validationLabels.size()) {
                QMessageBox::warning(this, "错误", "无法读取验证集数据");
                useValidation = false;
            }
        }
        QElapsedTimer timer;
        timer.start();
        QProgressDialog progress("训练中...", "取消", 0, 100, this);
        progress.setWindowModality(Qt::WindowModal);
        progress.show();
        int numberOfImages = 0;
        int imageSize = 0;
        vector<vector<double>> images = readMNISTImages(imagesPath.toStdString(), numberOfImages, imageSize);
        int numberOfLabels = 0;
        vector<int> labels = readMNISTLabels(labelsPath.toStdString(), numberOfLabels);
        if (images.empty() || labels.empty() || numberOfImages != numberOfLabels) {
            QMessageBox::warning(this, "错误", "无法读取MNIST数据或数据不匹配");
            return;
        }
        int epochs = 5;
        int batchSize = 10;
        // Round up so the trailing partial batch is trained as well; the
        // index bound check inside the batch loop handles the overhang.
        int numBatches = (numberOfImages + batchSize - 1) / batchSize;
        progress.setMaximum(epochs * numBatches);
        for (int epoch = 0; epoch < epochs; ++epoch) {
            for (int i = 0; i < numBatches; ++i) {
                if (progress.wasCanceled()) {
                    break;
                }
                for (int j = 0; j < batchSize; ++j) {
                    int index = i * batchSize + j;
                    if (index >= numberOfImages) {
                        break;
                    }
                    vector<double> target = toOneHot(labels[index], 10);
                    nn->train(images[index], target);
                }
                progress.setValue(epoch * numBatches + i);
                QApplication::processEvents(); // keep the UI responsive
            }
            if (progress.wasCanceled()) {
                break;
            }
        }
        progress.close();
        qint64 elapsed = timer.elapsed();
        qint64 minutes = elapsed / 60000;
        qint64 seconds = (elapsed % 60000) / 1000;
        qint64 milliseconds = elapsed % 1000;
        // Training accuracy, estimated on a random sample of 1000 when the set is larger.
        double finalAccuracy = 0.0;
        if (numberOfImages > 1000) {
            int sampleSize = 1000;
            vector<vector<double>> sampleImages;
            vector<int> sampleLabels;
            for (int i = 0; i < sampleSize; ++i) {
                int index = rand() % numberOfImages;
                sampleImages.push_back(images[index]);
                sampleLabels.push_back(labels[index]);
            }
            finalAccuracy = nn->calculateAccuracy(sampleImages, sampleLabels);
        } else {
            finalAccuracy = nn->calculateAccuracy(images, labels);
        }
        double validationAccuracy = 0.0;
        if (useValidation) {
            validationAccuracy = nn->calculateAccuracy(validationImages, validationLabels);
            accuracyLabel->setText(QString("训练准确率: %1%\n验证准确率: %2%")
                                       .arg(finalAccuracy * 100, 0, 'f', 2)
                                       .arg(validationAccuracy * 100, 0, 'f', 2));
        } else {
            accuracyLabel->setText(QString("训练准确率: %1%").arg(finalAccuracy * 100, 0, 'f', 2));
        }
        timeLabel->setText(QString("训练用时: %1分 %2秒 %3毫秒").arg(minutes).arg(seconds).arg(milliseconds));
        QMessageBox::information(this, "成功",
                                 QString("训练完成!\n最终准确率: %1%\n训练用时: %2分 %3秒 %4毫秒")
                                     .arg(finalAccuracy * 100, 0, 'f', 2)
                                     .arg(minutes).arg(seconds).arg(milliseconds));
    }

    void onLoadModelButtonClicked() {
        QString fileName = QFileDialog::getOpenFileName(this, "加载模型", "", "模型文件 (*.txt)");
        if (fileName.isEmpty()) {
            return;
        }
        // Report the real outcome instead of always claiming success.
        if (nn->loadModel(fileName.toStdString())) {
            QMessageBox::information(this, "成功", "模型加载成功!");
        } else {
            QMessageBox::warning(this, "错误", "模型加载失败!");
        }
    }

    void onSaveModelButtonClicked() {
        QString fileName = QFileDialog::getSaveFileName(this, "保存模型", "", "模型文件 (*.txt)");
        if (fileName.isEmpty()) {
            return;
        }
        if (nn->saveModel(fileName.toStdString())) {
            QMessageBox::information(this, "成功", "模型保存成功!");
        } else {
            QMessageBox::warning(this, "错误", "模型保存失败!");
        }
    }

    void onTestButtonClicked() {
        QString imagesPath = QFileDialog::getOpenFileName(this, "选择测试图像文件", "", "MNIST图像文件 (*.idx3-ubyte)");
        QString labelsPath = QFileDialog::getOpenFileName(this, "选择测试标签文件", "", "MNIST标签文件 (*.idx1-ubyte)");
        if (imagesPath.isEmpty() || labelsPath.isEmpty()) {
            QMessageBox::warning(this, "错误", "请选择测试数据文件");
            return;
        }
        QElapsedTimer timer;
        timer.start();
        QProgressDialog progress("测试中...", "取消", 0, 100, this);
        progress.setWindowModality(Qt::WindowModal);
        progress.show();
        int numberOfImages = 0;
        int imageSize = 0;
        vector<vector<double>> testImages = readMNISTImages(imagesPath.toStdString(), numberOfImages, imageSize);
        int numberOfLabels = 0;
        vector<int> testLabels = readMNISTLabels(labelsPath.toStdString(), numberOfLabels);
        if (testImages.empty() || testLabels.empty() || numberOfImages != numberOfLabels) {
            QMessageBox::warning(this, "错误", "无法读取测试数据或数据不匹配");
            return;
        }
        progress.setMaximum(numberOfImages);
        int correct = 0;
        int tested = 0; // samples actually evaluated (fewer than numberOfImages if cancelled)
        for (int i = 0; i < numberOfImages; ++i) {
            if (progress.wasCanceled()) {
                break;
            }
            int prediction = nn->predict(testImages[i]);
            if (prediction == testLabels[i]) {
                correct++;
            }
            tested++;
            progress.setValue(i + 1);
            QApplication::processEvents();
        }
        progress.close();
        qint64 elapsed = timer.elapsed();
        qint64 minutes = elapsed / 60000;
        qint64 seconds = (elapsed % 60000) / 1000;
        qint64 milliseconds = elapsed % 1000;
        // Divide by the number actually tested so cancelling does not deflate the accuracy.
        double accuracy = tested > 0 ? static_cast<double>(correct) / tested : 0.0;
        accuracyLabel->setText(QString("测试准确率: %1%").arg(accuracy * 100, 0, 'f', 2));
        timeLabel->setText(QString("测试用时: %1分 %2秒 %3毫秒").arg(minutes).arg(seconds).arg(milliseconds));
        QMessageBox::information(this, "测试结果",
                                 QString("测试完成!\n准确率: %1% (%2/%3)\n测试用时: %4分 %5秒 %6毫秒")
                                     .arg(accuracy * 100, 0, 'f', 2)
                                     .arg(correct)
                                     .arg(tested)
                                     .arg(minutes).arg(seconds).arg(milliseconds));
    }

    void onSaveImageButtonClicked() {
        QString fileName = QFileDialog::getSaveFileName(this, "保存图像", "", "PNG图像 (*.png);;JPEG图像 (*.jpg *.jpeg)");
        if (!fileName.isEmpty()) {
            if (drawingCanvas->saveImage(fileName)) {
                QMessageBox::information(this, "成功", "图像保存成功!");
            } else {
                QMessageBox::warning(this, "错误", "无法保存图像!");
            }
        }
    }

    void onLoadImageButtonClicked() {
        QString fileName = QFileDialog::getOpenFileName(this, "加载图像", "", "图像文件 (*.png *.jpg *.jpeg *.bmp)");
        if (fileName.isEmpty()) {
            return;
        }
        QElapsedTimer timer;
        timer.start();
        vector<double> imageData = ImageProcessor::loadAndPreprocessImage(fileName);
        if (imageData.empty()) {
            QMessageBox::warning(this, "错误", "无法加载或处理图像!");
            return;
        }
        int prediction = nn->predict(imageData);
        qint64 elapsed = timer.elapsed();
        vector<double> output = nn->getConfidence(imageData);
        double confidence = output[prediction] * 100;
        predictionLabel->setText(QString("预测: %1").arg(prediction));
        confidenceLabel->setText(QString("置信度: %1%").arg(confidence, 0, 'f', 2));
        timeLabel->setText(QString("用时: %1 毫秒").arg(elapsed));
        for (int i = 0; i < 10; ++i) {
            int value = static_cast<int>(output[i] * 100);
            confidenceBars[i]->setValue(value);
            confidenceLabels[i]->setText(QString("%1%").arg(value));
        }
        QPixmap pixmap(fileName);
        if (!pixmap.isNull()) {
            QMessageBox::information(this, "图像已加载",
                                     QString("已成功加载图像并识别!\n预测结果: %1\n置信度: %2%")
                                         .arg(prediction)
                                         .arg(confidence, 0, 'f', 2));
        }
    }

private:
    void setupUI() {
        QWidget *centralWidget = new QWidget(this);
        setCentralWidget(centralWidget);
        QHBoxLayout *mainLayout = new QHBoxLayout(centralWidget);

        // Left: drawing canvas (stretch 50).
        drawingCanvas = new DrawingCanvas(this);
        mainLayout->addWidget(drawingCanvas, 50);

        // Right: control panel.
        QWidget *rightPanel = new QWidget(this);
        QVBoxLayout *rightLayout = new QVBoxLayout(rightPanel);

        QPushButton *clearButton = new QPushButton("清除画布", this);
        QPushButton *recognizeButton = new QPushButton("识别数字", this);
        QPushButton *trainButton = new QPushButton("训练网络", this);
        QPushButton *testButton = new QPushButton("测试网络", this);
        QPushButton *loadModelButton = new QPushButton("加载模型", this);
        QPushButton *saveModelButton = new QPushButton("保存模型", this);
        QPushButton *saveImageButton = new QPushButton("保存图像", this);
        QPushButton *loadImageButton = new QPushButton("加载图像", this);

        QString buttonStyle = "QPushButton { padding: 8px; font-size: 12px; }";
        clearButton->setStyleSheet(buttonStyle);
        recognizeButton->setStyleSheet(buttonStyle);
        trainButton->setStyleSheet(buttonStyle);
        testButton->setStyleSheet(buttonStyle);
        loadModelButton->setStyleSheet(buttonStyle);
        saveModelButton->setStyleSheet(buttonStyle);
        saveImageButton->setStyleSheet(buttonStyle);
        loadImageButton->setStyleSheet(buttonStyle);

        // Result labels.
        QWidget *resultWidget = new QWidget(this);
        QVBoxLayout *resultLayout = new QVBoxLayout(resultWidget);
        predictionLabel = new QLabel("预测: -", this);
        confidenceLabel = new QLabel("置信度: -", this);
        accuracyLabel = new QLabel("准确率: -", this);
        timeLabel = new QLabel("用时: -", this);
        QFont labelFont = predictionLabel->font();
        labelFont.setPointSize(14);
        predictionLabel->setFont(labelFont);
        confidenceLabel->setFont(labelFont);
        accuracyLabel->setFont(labelFont);
        timeLabel->setFont(labelFont);
        predictionLabel->setAlignment(Qt::AlignCenter);
        confidenceLabel->setAlignment(Qt::AlignCenter);
        accuracyLabel->setAlignment(Qt::AlignCenter);
        timeLabel->setAlignment(Qt::AlignCenter);
        resultLayout->addWidget(predictionLabel);
        resultLayout->addWidget(confidenceLabel);
        resultLayout->addWidget(accuracyLabel);
        resultLayout->addWidget(timeLabel);

        // Per-digit confidence bars. Constructing the grid with the group box as
        // parent already installs it as the group box's layout.
        QGroupBox *confidenceGroup = new QGroupBox("数字置信度", this);
        QGridLayout *confidenceLayout = new QGridLayout(confidenceGroup);
        for (int i = 0; i < 10; ++i) {
            QLabel *digitLabel = new QLabel(QString::number(i), this);
            digitLabel->setAlignment(Qt::AlignCenter);
            digitLabel->setFixedWidth(20);
            QProgressBar *bar = new QProgressBar(this);
            bar->setRange(0, 100);
            bar->setValue(0);
            bar->setTextVisible(false);
            bar->setFixedHeight(15);
            QLabel *percentLabel = new QLabel("0%", this);
            percentLabel->setAlignment(Qt::AlignCenter);
            percentLabel->setFixedWidth(40);
            confidenceLayout->addWidget(digitLabel, i, 0);
            confidenceLayout->addWidget(bar, i, 1);
            confidenceLayout->addWidget(percentLabel, i, 2);
            confidenceBars.push_back(bar);
            confidenceLabels.push_back(percentLabel);
        }

        rightLayout->addWidget(clearButton);
        rightLayout->addWidget(recognizeButton);
        rightLayout->addWidget(trainButton);
        rightLayout->addWidget(testButton);
        rightLayout->addWidget(loadModelButton);
        rightLayout->addWidget(saveModelButton);
        rightLayout->addWidget(saveImageButton);
        rightLayout->addWidget(loadImageButton);
        rightLayout->addWidget(resultWidget);
        rightLayout->addWidget(confidenceGroup);
        rightLayout->addStretch();
        mainLayout->addWidget(rightPanel, 50);

        connect(clearButton, &QPushButton::clicked, this, &MainWindow::onClearButtonClicked);
        connect(recognizeButton, &QPushButton::clicked, this, &MainWindow::onRecognizeButtonClicked);
        connect(trainButton, &QPushButton::clicked, this, &MainWindow::onTrainButtonClicked);
        connect(testButton, &QPushButton::clicked, this, &MainWindow::onTestButtonClicked);
        connect(loadModelButton, &QPushButton::clicked, this, &MainWindow::onLoadModelButtonClicked);
        connect(saveModelButton, &QPushButton::clicked, this, &MainWindow::onSaveModelButtonClicked);
        connect(saveImageButton, &QPushButton::clicked, this, &MainWindow::onSaveImageButtonClicked);
        connect(loadImageButton, &QPushButton::clicked, this, &MainWindow::onLoadImageButtonClicked);

        setWindowTitle("手写数字识别系统");
        resize(1000, 500);
    }

    // Create a 784-128-10 network and try to load pretrained weights; the
    // randomly initialized weights stay in place when loading fails.
    void setupNeuralNetwork() {
        nn = new NeuralNetwork(784, 128, 10, 0.1);
        // loadModel reports failure through its return value (it does not throw).
        if (!nn->loadModel("mnist_model.txt")) {
            qDebug() << "无法加载预训练模型,使用随机初始化的权重";
        }
    }

    DrawingCanvas *drawingCanvas = nullptr;
    QLabel *predictionLabel = nullptr;
    QLabel *confidenceLabel = nullptr;
    QLabel *accuracyLabel = nullptr;
    QLabel *timeLabel = nullptr;
    NeuralNetwork *nn = nullptr;
    vector<QProgressBar*> confidenceBars;
    vector<QLabel*> confidenceLabels;
};

int main(int argc, char *argv[]) {
    QApplication app(argc, argv);
    MainWindow window;
    window.show();
    return app.exec();
}

#include "main.moc"
加载图像实现一张图片识别多个数字的功能
最新发布
09-17
"""Handwritten digit recognition board with a multi-model accuracy comparison.

Trains several scikit-learn compatible classifiers on
``sklearn.datasets.load_digits`` (8x8 images, pixel values 0-16), shows their
test accuracy side by side, and lets the user draw a digit on a Tk canvas and
classify it with the selected model.
"""
import csv
import os
import tkinter as tk
from tkinter import filedialog, messagebox, ttk

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Optional gradient-boosting back-ends; their models are disabled when missing.
try:
    import xgboost as xgb
except ImportError:
    xgb = None
    print("警告: 未安装XGBoost库,无法使用XGBoost模型")

try:
    import lightgbm as lgb
except ImportError:
    lgb = None
    print("警告: 未安装LightGBM库,无法使用LightGBM模型")

# Chinese-capable fonts for matplotlib labels; keep the minus sign renderable.
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False

RANDOM_STATE = 42
TEST_SIZE = 0.3

# model key -> (display name, estimator class or None if unavailable, scaler class or None)
MODEL_METADATA = {
    'svm': ('支持向量机(SVM)', SVC, StandardScaler),
    'dt': ('决策树(DT)', DecisionTreeClassifier, None),
    'rf': ('随机森林(RF)', RandomForestClassifier, None),
    'xgb': ('XGBoost(XGB)', xgb.XGBClassifier if xgb else None, None),
    'lgb': ('LightGBM(LGB)', lgb.LGBMClassifier if lgb else None, None),
    'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler),
    'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler),
    'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None),
}

# Constructor keyword arguments per model key (built once, not on every call).
MODEL_PARAMS = {
    'svm': {'probability': True, 'random_state': RANDOM_STATE},
    'dt': {'random_state': RANDOM_STATE},
    'rf': {'n_estimators': 100, 'random_state': RANDOM_STATE},
    # softprob makes predict_proba return genuine per-class probabilities.
    'xgb': {'objective': 'multi:softprob', 'random_state': RANDOM_STATE},
    'lgb': {'objective': 'multiclass', 'random_state': RANDOM_STATE, 'num_class': 10,
            'max_depth': 5, 'min_child_samples': 10, 'learning_rate': 0.1,
            'force_col_wise': True, 'verbose': -1},
    'mlp': {'hidden_layer_sizes': (100, 50), 'max_iter': 1000, 'random_state': RANDOM_STATE},
    'knn': {'n_neighbors': 5, 'weights': 'distance'},
    'nb': {},
}


def get_split_data(digits_dataset, test_size=TEST_SIZE, random_state=RANDOM_STATE):
    """Split a digits dataset into train/test sets.

    :param digits_dataset: object with ``data`` (features) and ``target`` (labels)
    :param test_size: fraction of samples used for testing (default 0.3)
    :param random_state: seed so every caller gets the same split
    :return: ``X_train, X_test, y_train, y_test``
    """
    X, y = digits_dataset.data, digits_dataset.target
    return train_test_split(X, y, test_size=test_size, random_state=random_state)


def _parse_accuracy(value, default=None):
    """Return an accuracy string such as ``"0.9870"`` as a float.

    Non-numeric entries ("N/A", "错误: ...") yield *default*.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


class ModelFactory:
    """Create, train and evaluate the classifiers listed in ``MODEL_METADATA``."""

    @staticmethod
    def create_model(model_type):
        """Instantiate the estimator and (optional) scaler for *model_type*.

        :raises KeyError: unknown model key
        :raises ImportError: the model's optional library is not installed
        :return: ``(model, scaler_or_None)``
        """
        name, model_cls, scaler_cls = MODEL_METADATA[model_type]
        if not model_cls:
            raise ImportError(f"{name}模型依赖库未安装")
        params = MODEL_PARAMS.get(model_type, {'random_state': RANDOM_STATE})
        model = model_cls(**params)
        scaler = scaler_cls() if scaler_cls else None
        return model, scaler

    @staticmethod
    def prepare_features(X, scaler=None, model_type=None, fit=False):
        """Scale *X* and apply model-specific formatting.

        :param X: 2-D feature matrix
        :param scaler: optional scaler; fitted first when ``fit`` is True
        :param model_type: model key; ``'lgb'`` gets a DataFrame with fixed
            column names so training and prediction use identical feature names
        :param fit: use ``fit_transform`` (training) instead of ``transform``
        :return: the transformed features
        """
        if scaler:
            X = scaler.fit_transform(X) if fit else scaler.transform(X)
        if model_type == 'lgb':
            X = np.asarray(X)
            X = pd.DataFrame(X, columns=[f"pixel{i}" for i in range(X.shape[1])])
        return X

    @staticmethod
    def train_model(model, X_train, y_train, scaler=None, model_type=None):
        """Fit *scaler* (if any) and *model* on the training data; return the model."""
        X_train = ModelFactory.prepare_features(X_train, scaler, model_type, fit=True)
        model.fit(X_train, y_train)
        return model

    @staticmethod
    def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
        """Return the accuracy of a fitted *model* on the test data."""
        X_test = ModelFactory.prepare_features(X_test, scaler, model_type)
        y_pred = model.predict(X_test)
        return accuracy_score(y_test, y_pred)

    @staticmethod
    def train_and_evaluate(model_type, X_train, y_train, X_test, y_test):
        """Create, train and evaluate one model.

        :return: ``(model, scaler_or_None, accuracy)``
        :raises Exception: re-raises any creation/training/evaluation error after logging it
        """
        try:
            model, scaler = ModelFactory.create_model(model_type)
            model = ModelFactory.train_model(model, X_train, y_train, scaler, model_type)
            accuracy = ModelFactory.evaluate_model(model, X_test, y_test, scaler, model_type)
            return model, scaler, accuracy
        except Exception as e:
            print(f"模型 {model_type} 训练/评估错误: {str(e)}")
            raise


def evaluate_all_models(digits_dataset, model_cache=None):
    """Train and score every model in ``MODEL_METADATA`` on the same split.

    :param digits_dataset: dataset with ``data`` and ``target``
    :param model_cache: optional dict; when given, every successfully trained
        model is stored as ``model_cache[key] = (model, scaler, accuracy, key)``
        so callers can reuse it instead of training again
    :return: list of ``{"模型名称": name, "准确率": str}`` sorted best-first;
        unavailable/failed models ("N/A" / "错误: ...") sort last
    """
    print("\n=== 模型评估 ===")
    X_train, X_test, y_train, y_test = get_split_data(digits_dataset)
    results = []
    for model_type, (name, model_cls, _) in MODEL_METADATA.items():
        print(f"评估模型: {name} ({model_type})")
        if not model_cls:
            results.append({"模型名称": name, "准确率": "N/A"})
            continue
        try:
            model, scaler, accuracy = ModelFactory.train_and_evaluate(
                model_type, X_train, y_train, X_test, y_test
            )
        except Exception as e:
            results.append({"模型名称": name, "准确率": f"错误: {str(e)}"})
            continue
        results.append({"模型名称": name, "准确率": f"{accuracy:.4f}"})
        if model_cache is not None:
            model_cache[model_type] = (model, scaler, accuracy, model_type)
    results.sort(key=lambda r: _parse_accuracy(r["准确率"], -1.0), reverse=True)
    print(pd.DataFrame(results))
    return results


class HandwritingBoard:
    """Tk GUI: draw a digit, classify it and compare model accuracies."""

    PEN_WIDTH = 16                    # stroke width in pixels (screen and PIL image)
    DEFAULT_MODEL = 'svm'
    HINT_TAG = "hint"                 # canvas tag of the gray placeholder text
    SECTION_FONT = ("Arial", 10, "bold")

    def __init__(self, root, model_factory, digits):
        self.root = root
        self.root.title("手写数字识别系统 (含模型性能对比)")
        self.root.geometry("1000x600")
        self.model_factory = model_factory
        self.digits = digits
        # model key -> (model, scaler, accuracy, model key)
        self.model_cache = {}
        # Cached output of evaluate_all_models (deterministic for a fixed split/seed).
        self.performance_results = None
        self.current_model = None
        self.scaler = None
        self.current_model_type = None
        self.has_drawn = False
        self.drawing = False
        self.last_x, self.last_y = 0, 0
        # (list of 64 feature values, label) pairs collected from the canvas
        self.custom_data = []
        self.data_dir = "custom_digits_data"
        os.makedirs(self.data_dir, exist_ok=True)

        # Off-screen grayscale copy of the drawing (white background, black ink).
        self.canvas_width = 600
        self.canvas_height = 600
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_obj = ImageDraw.Draw(self.image)

        self.create_widgets()
        self.init_default_model()
        self.canvas.bind("<Configure>", self.on_canvas_resize)

    # ------------------------------------------------------------------ UI setup
    def create_widgets(self):
        """Build the model selector, the drawing canvas and the control panel."""
        self._create_model_selector()
        middle_frame = tk.Frame(self.root)
        middle_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=5)
        self._create_drawing_canvas(middle_frame)
        self._create_control_panel(middle_frame)

    def _create_model_selector(self):
        """Top bar: combobox listing the models whose libraries are installed."""
        top_frame = tk.Frame(self.root)
        top_frame.pack(fill=tk.X, padx=10, pady=5)
        tk.Label(top_frame, text="选择模型:", font=("Arial", 10)).pack(side=tk.LEFT, padx=5)

        self.available_models = [
            (key, name) for key, (name, model_cls, _) in MODEL_METADATA.items() if model_cls
        ]
        self.model_combobox = ttk.Combobox(
            top_frame,
            values=[name for _, name in self.available_models],
            state="readonly",
            width=15,
            font=("Arial", 10),
        )
        keys = [key for key, _ in self.available_models]
        if self.DEFAULT_MODEL in keys:
            # Keep the combobox in sync with the model loaded at startup.
            self.model_combobox.current(keys.index(self.DEFAULT_MODEL))
        self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select)
        self.model_combobox.pack(side=tk.LEFT, padx=5)

    def _create_drawing_canvas(self, parent):
        """Left side: the white drawing canvas with mouse bindings."""
        canvas_frame = tk.Frame(parent)
        canvas_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=(0, 10))
        self.canvas = tk.Canvas(canvas_frame, bg="white")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.bind("<Button-1>", self.start_draw)
        self.canvas.bind("<B1-Motion>", self.draw)
        self.canvas.bind("<ButtonRelease-1>", self.stop_draw)
        self._draw_hint()

    def _draw_hint(self):
        """(Re)draw the gray placeholder text at the canvas centre."""
        self.canvas.delete(self.HINT_TAG)
        self.canvas.create_text(self.canvas_width / 2, self.canvas_height / 2,
                                text="绘制数字", fill="gray", font=("Arial", 16),
                                tags=self.HINT_TAG)

    @staticmethod
    def _add_button(parent, text, command):
        """Pack a full-width button with the panel's common style into *parent*."""
        tk.Button(parent, text=text, command=command, width=12, height=1,
                  font=("Arial", 10)).pack(fill=tk.X, pady=4)

    def _create_control_panel(self, parent):
        """Right side: two-column grid with actions, results and comparisons."""
        control_frame = tk.Frame(parent)
        control_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
        control_frame.grid_columnconfigure(0, weight=1)
        control_frame.grid_columnconfigure(1, weight=1)

        # Row 0: current model.
        current_model_frame = tk.LabelFrame(control_frame, text="当前模型", font=self.SECTION_FONT)
        current_model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", pady=(0, 8), padx=3)
        self.model_label = tk.Label(current_model_frame, text="支持向量机(SVM)", font=("Arial", 12),
                                    relief=tk.RAISED, padx=8)
        self.model_label.pack(fill=tk.X, pady=5)

        # Row 1, left: actions.
        button_frame = tk.LabelFrame(control_frame, text="操作", font=self.SECTION_FONT)
        button_frame.grid(row=1, column=0, sticky="nsew", pady=(0, 8), padx=(3, 3))
        for text, command in (("识别", self.recognize),
                              ("清除", self.clear_canvas),
                              ("样本", self.show_samples),
                              ("对比图表", self.show_performance_chart)):
            self._add_button(button_frame, text, command)

        # Row 1, right: training-set management.
        train_set_frame = tk.LabelFrame(control_frame, text="训练集管理", font=self.SECTION_FONT)
        train_set_frame.grid(row=1, column=1, sticky="nsew", pady=(0, 8), padx=(3, 3))
        for text, command in (("保存为训练样本", self.save_as_training_sample),
                              ("保存全部训练集", self.save_all_training_data),
                              ("加载训练集", self.load_training_data)):
            self._add_button(train_set_frame, text, command)

        # Row 2: recognition result, probability and debug/status text.
        result_frame = tk.LabelFrame(control_frame, text="识别结果", font=self.SECTION_FONT)
        result_frame.grid(row=2, column=0, columnspan=2, sticky="ew", pady=(0, 8), padx=3)
        self.result_label = tk.Label(result_frame, text="请绘制数字", font=("Arial", 24))
        self.result_label.pack(pady=5)
        self.prob_label = tk.Label(result_frame, text="", font=("Arial", 10))
        self.prob_label.pack(pady=3)
        self.debug_label = tk.Label(result_frame, text="", font=("Arial", 9), wraplength=250)
        self.debug_label.pack(pady=3)

        # Row 3, left: confidence bar.
        self.confidence_frame = tk.LabelFrame(control_frame, text="识别置信度", font=self.SECTION_FONT)
        self.confidence_frame.grid(row=3, column=0, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self.confidence_canvas = tk.Canvas(self.confidence_frame, bg="white", height=80)
        self.confidence_canvas.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)

        # Row 4, left: top candidate digits.
        self.candidates_frame = tk.LabelFrame(control_frame, text="可能的数字", font=self.SECTION_FONT)
        self.candidates_frame.grid(row=4, column=0, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self.candidates_tree = ttk.Treeview(self.candidates_frame, columns=("数字", "概率"), show="headings")
        self.candidates_tree.column("数字", width=70, anchor=tk.CENTER)
        self.candidates_tree.column("概率", width=70, anchor=tk.CENTER)
        self.candidates_tree.heading("数字", text="数字")
        self.candidates_tree.heading("概率", text="概率")
        self.candidates_tree.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)

        # Rows 3-4, right: model comparison table.
        self.performance_frame = tk.LabelFrame(control_frame, text="模型性能对比", font=self.SECTION_FONT)
        self.performance_frame.grid(row=3, column=1, rowspan=2, sticky="nsew", pady=(0, 8), padx=(3, 3))
        self.create_performance_table()

    # ------------------------------------------------------ model comparison
    def create_performance_table(self):
        """(Re)create the model comparison Treeview and fill it."""
        for widget in self.performance_frame.winfo_children():
            widget.destroy()
        columns = ("模型名称", "准确率")
        self.performance_tree = ttk.Treeview(self.performance_frame, columns=columns, show="headings")
        self.performance_tree.column("模型名称", width=120, anchor=tk.W)
        self.performance_tree.column("准确率", width=80, anchor=tk.CENTER)
        self.performance_tree.heading("模型名称", text="模型名称")
        self.performance_tree.heading("准确率", text="准确率")
        scrollbar = ttk.Scrollbar(self.performance_frame, orient=tk.VERTICAL,
                                  command=self.performance_tree.yview)
        self.performance_tree.configure(yscroll=scrollbar.set)
        self.performance_tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.load_performance_data()

    def _get_performance_results(self, refresh=False):
        """Return cached evaluation results, computing them (once) when needed.

        Training every model is slow; results are deterministic for the fixed
        split and seeds, so they are computed once and the trained models are
        stored in ``self.model_cache`` for reuse by ``change_model``.
        """
        if refresh or self.performance_results is None:
            self.performance_results = evaluate_all_models(self.digits, self.model_cache)
        return self.performance_results

    def load_performance_data(self, refresh=False):
        """Fill the comparison table; highlight the best model if it has a score."""
        results = self._get_performance_results(refresh)
        for item in self.performance_tree.get_children():
            self.performance_tree.delete(item)
        for i, result in enumerate(results):
            is_best = i == 0 and _parse_accuracy(result["准确率"]) is not None
            self.performance_tree.insert("", tk.END, values=(result["模型名称"], result["准确率"]),
                                         tags=("highlight",) if is_best else ())
        self.performance_tree.tag_configure("highlight", background="#e6f7ff", font=("Arial", 9, "bold"))

    def show_performance_chart(self):
        """Show a horizontal bar chart of the models that have a numeric accuracy."""
        valid_results = []
        for result in self._get_performance_results():
            accuracy = _parse_accuracy(result["准确率"])
            if accuracy is not None:
                valid_results.append((result["模型名称"], accuracy))
        if not valid_results:
            messagebox.showinfo("提示", "没有可用的模型性能数据来生成图表")
            return
        valid_results.sort(key=lambda item: item[1], reverse=True)

        plt.figure(figsize=(10, 6))
        models, accuracies = zip(*valid_results)
        bars = plt.barh(models, accuracies, color='skyblue')
        plt.xlabel('准确率', fontsize=10)
        plt.ylabel('模型', fontsize=10)
        plt.title('各模型在手写数字识别上的性能对比', fontsize=12)
        plt.xlim(0, 1.05)
        for bar in bars:
            width = bar.get_width()
            plt.text(width + 0.01, bar.get_y() + bar.get_height() / 2, f'{width:.4f}',
                     ha='left', va='center', fontsize=8)
        plt.tight_layout()
        plt.show()
        plt.close()

    # --------------------------------------------------------------- drawing
    def _stamp_dot(self, x, y):
        """Draw a round pen tip at (x, y) on both the canvas and the PIL image."""
        r = self.PEN_WIDTH // 2
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill="black", outline="black")
        self.draw_obj.ellipse([x - r, y - r, x + r, y + r], fill=0)

    def start_draw(self, event):
        """Begin a stroke: remove the placeholder and mark the press point."""
        self.drawing = True
        self.last_x, self.last_y = event.x, event.y
        self.canvas.delete(self.HINT_TAG)
        self._stamp_dot(event.x, event.y)

    def draw(self, event):
        """Extend the stroke to the mouse position.

        The same segment is drawn on screen and into the PIL image so what the
        user sees is what gets recognized; the dot rounds the joint and fills
        gaps left by fast mouse movement.
        """
        if not self.drawing:
            return
        x, y = event.x, event.y
        self.canvas.create_line(self.last_x, self.last_y, x, y, width=self.PEN_WIDTH,
                                fill="black", capstyle=tk.ROUND)
        self.draw_obj.line([self.last_x, self.last_y, x, y], fill=0, width=self.PEN_WIDTH)
        self._stamp_dot(x, y)
        self.last_x, self.last_y = x, y

    def stop_draw(self, event):
        """End the current stroke."""
        self.drawing = False
        self.has_drawn = True

    def _reset_results(self):
        """Clear the result labels and confidence widgets."""
        self.result_label.config(text="请绘制数字")
        self.prob_label.config(text="")
        self.debug_label.config(text="")
        self.clear_confidence_display()

    def clear_canvas(self):
        """Erase the drawing and all recognition output."""
        self.canvas.delete("all")
        width, height = self.canvas.winfo_width(), self.canvas.winfo_height()
        # Tk reports 1x1 before the window is mapped; keep the known size then.
        if width > 1 and height > 1:
            self.canvas_width, self.canvas_height = width, height
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_obj = ImageDraw.Draw(self.image)
        self._draw_hint()
        self.has_drawn = False
        self.drawing = False
        self._reset_results()

    def clear_confidence_display(self):
        """Empty the confidence bar and the candidate list."""
        self.confidence_canvas.delete("all")
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)

    def on_canvas_resize(self, event):
        """Resize the off-screen image to the canvas and keep the drawing.

        Canvas items keep their coordinates on resize, so the old image is
        pasted at the origin; the parts outside a smaller canvas are cropped.
        """
        if event.width <= 1 or event.height <= 1:
            return  # initial pre-mapping configure event
        if (event.width, event.height) == (self.canvas_width, self.canvas_height):
            return
        self.canvas_width, self.canvas_height = event.width, event.height
        resized = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        resized.paste(self.image, (0, 0))
        self.image = resized
        self.draw_obj = ImageDraw.Draw(self.image)
        if self.canvas.find_withtag(self.HINT_TAG):
            self.canvas.coords(self.HINT_TAG, self.canvas_width / 2, self.canvas_height / 2)

    # ----------------------------------------------------------- recognition
    def preprocess_image(self):
        """Convert the drawing into a 64-value vector in ``load_digits`` format.

        ``load_digits`` pixels are 0 (background) to 16 (ink). The drawing is
        black ink on white, so it is thresholded with inversion (ink -> 255),
        cropped to the bounding box of all strokes, centred in a square padded
        with background (0), shrunk to 8x8 and rescaled to 0..16.

        :return: uint8 array of 64 values, or ``np.zeros(64)`` (with a status
            message) when no ink is found
        """
        img_array = np.array(self.image)
        img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
        _, img_array = cv2.threshold(img_array, 127, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(img_array, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            self.debug_label.config(text="未检测到有效数字,请重新绘制")
            return np.zeros(64)
        # Bound every stroke, not only the largest, so multi-stroke digits stay whole.
        x, y, w, h = cv2.boundingRect(np.vstack(contours))
        digit = img_array[y:y + h, x:x + w]
        size = max(w, h)
        padded = np.zeros((size, size), dtype=np.uint8)  # background is 0 after inversion
        offset_x = (size - w) // 2
        offset_y = (size - h) // 2
        padded[offset_y:offset_y + h, offset_x:offset_x + w] = digit
        resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA)
        normalized = np.rint(resized.astype(np.float64) * 16 / 255).astype(np.uint8)
        return normalized.flatten()

    def recognize(self):
        """Classify the drawing and show prediction, confidence and top-3 candidates."""
        if not self.has_drawn:
            self.debug_label.config(text="请先绘制数字再识别")
            return
        if self.current_model_type is None:
            self.debug_label.config(text="模型类型未正确设置,请重新加载模型")
            return
        if self.current_model is None:
            self.debug_label.config(text="模型未加载,请选择并加载模型")
            return

        img = self.preprocess_image()
        if img.sum() == 0:
            self.clear_confidence_display()
            return

        try:
            img_input = self.model_factory.prepare_features(
                img.reshape(1, -1), self.scaler, self.current_model_type)
            pred = self.current_model.predict(img_input)[0]
            self.result_label.config(text=f"识别结果: {pred}")

            if not hasattr(self.current_model, 'predict_proba'):
                self.prob_label.config(text="置信度: 该模型不支持概率输出")
                self.debug_label.config(text="")
                self.clear_confidence_display()
                return

            probs = self.current_model.predict_proba(img_input)[0]
            # Map labels to probability columns via classes_ rather than
            # assuming label == column index.
            classes = list(getattr(self.current_model, 'classes_', range(len(probs))))
            confidence = probs[classes.index(pred)]
            self.prob_label.config(text=f"置信度: {confidence:.2%}")
            self.update_confidence_display(confidence)

            top3 = sorted(zip(classes, probs), key=lambda item: item[1], reverse=True)[:3]
            self.update_candidates_display(top3)
            prob_text = "\n".join([f"数字 {i}: 概率 {p:.2%}" for i, p in top3])
            self.debug_label.config(text=prob_text)
        except Exception as e:  # GUI boundary: report instead of crashing the app
            self.debug_label.config(text=f"识别错误: {str(e)}")
            self.clear_confidence_display()
            print("识别异常:", e)

    def update_confidence_display(self, confidence):
        """Draw a bar for *confidence* (0..1) with a 0-100 scale underneath."""
        canvas = self.confidence_canvas
        canvas.delete("all")
        # winfo_width() is 1 until the widget is mapped; fall back to 300px then.
        canvas_width = canvas.winfo_width()
        if canvas_width <= 1:
            canvas_width = 300
        left, right = 15, canvas_width - 15
        canvas.create_rectangle(left, 15, right, 45, fill="#f0f0f0", outline="gray")
        bar_width = int((right - left) * confidence)
        canvas.create_rectangle(left, 15, left + bar_width, 45,
                                fill=self.get_confidence_color(confidence), outline="")
        canvas.create_text(canvas_width / 2, 30, text=f"置信度: {confidence:.2%}", font=("Arial", 9))
        for i in range(11):
            x_pos = left + i * (right - left) / 10
            canvas.create_line(x_pos, 45, x_pos, 50, width=1)
            if i % 2 == 0:  # label every 20%
                canvas.create_text(x_pos, 60, text=f"{i * 10}", font=("Arial", 7))

    def get_confidence_color(self, confidence):
        """Green for >= 0.9, yellow for >= 0.7, red otherwise."""
        if confidence >= 0.9:
            return "#2ecc71"
        if confidence >= 0.7:
            return "#f1c40f"
        return "#e74c3c"

    def update_candidates_display(self, candidates):
        """Replace the candidate list with ``(digit, probability)`` pairs."""
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)
        for digit, prob in candidates:
            self.candidates_tree.insert("", tk.END, values=(digit, f"{prob:.2%}"))

    def show_samples(self):
        """Show the first dataset image of each digit 0-9."""
        plt.figure(figsize=(10, 5))
        for i in range(10):
            plt.subplot(2, 5, i + 1)
            sample_idx = np.where(self.digits.target == i)[0][0]
            plt.imshow(self.digits.images[sample_idx], cmap="gray")
            plt.title(f"数字 {i}", fontsize=10)
            plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()

    # -------------------------------------------------------- model switching
    def on_model_select(self, event):
        """Combobox handler: switch to the model whose display name was chosen."""
        selected_name = self.model_combobox.get()
        model_type = {name: key for key, name in self.available_models}[selected_name]
        self.change_model(model_type)

    def change_model(self, model_type):
        """Make *model_type* the active model, training it only if not cached.

        The current drawing is kept so it can be re-recognized with the new
        model; stale results from the previous model are cleared.
        """
        print(f"触发 change_model,选中模型键: {model_type}")
        model_name = MODEL_METADATA.get(model_type, (model_type,))[0]

        cached = self.model_cache.get(model_type)
        if cached is not None:
            self.current_model, self.scaler, accuracy, self.current_model_type = cached
            self._reset_results()
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.debug_label.config(text=f"已从缓存加载 {model_name}")
            return

        print(f"\n=== 开始加载 {model_name} 模型 ===")
        X_train, X_test, y_train, y_test = get_split_data(self.digits)
        try:
            model, scaler, accuracy = self.model_factory.train_and_evaluate(
                model_type, X_train, y_train, X_test, y_test
            )
        except Exception as e:
            self.debug_label.config(text=f"模型加载失败: {str(e)}")
            print(f"加载异常: {e}\n")
            return

        self.current_model, self.scaler, self.current_model_type = model, scaler, model_type
        self.model_cache[model_type] = (model, scaler, accuracy, model_type)
        self._reset_results()
        self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
        self.debug_label.config(text=f"模型加载完成,测试集准确率: {accuracy:.4f}")
        print(f"=== {model_name} 加载完成,准确率 {accuracy:.4f} ===\n")

    def init_default_model(self):
        """Load the default model (SVM) at startup."""
        self.change_model(self.DEFAULT_MODEL)

    # ------------------------------------------------------------ training set
    def save_as_training_sample(self):
        """Ask for a label (0-9) and store the current drawing as a sample."""
        if not self.has_drawn:
            self.debug_label.config(text="请先绘制数字再保存")
            return
        img = self.preprocess_image()
        if img.sum() == 0:
            self.debug_label.config(text="未检测到有效数字,无法保存")
            return

        label_window = tk.Toplevel(self.root)
        label_window.title("输入数字标签")
        label_window.geometry("300x150")
        label_window.transient(self.root)
        tk.Label(label_window, text="请输入您绘制的数字 (0-9):", font=("Arial", 10)).pack(pady=10)
        entry = tk.Entry(label_window, font=("Arial", 12), width=8)
        entry.pack(pady=5)
        entry.focus_set()

        def save_with_label(_event=None):
            try:
                label = int(entry.get())
                if not 0 <= label <= 9:
                    raise ValueError("标签必须是0到9之间的数字")
            except ValueError as e:
                self.debug_label.config(text=f"输入错误: {str(e)}")
                return
            self.custom_data.append((img.tolist(), label))
            self.debug_label.config(text=f"已保存数字 {label} 到训练集 (当前共有 {len(self.custom_data)} 个样本)")
            label_window.destroy()

        entry.bind("<Return>", save_with_label)
        tk.Button(label_window, text="保存", command=save_with_label, width=10, height=1,
                  font=("Arial", 10)).pack(pady=8)
        label_window.grab_set()

    def save_all_training_data(self):
        """Write all collected samples to a CSV (64 pixel columns + label)."""
        if not self.custom_data:
            self.debug_label.config(text="没有训练数据可保存")
            return
        file_path = filedialog.asksaveasfilename(
            defaultextension=".csv",
            filetypes=[("CSV文件", "*.csv"), ("所有文件", "*.*")],
            initialfile="custom_digits_training.csv",
            title="保存训练集",
        )
        if not file_path:
            return
        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
                writer.writerows(list(img_data) + [label] for img_data, label in self.custom_data)
            self.debug_label.config(text=f"已保存 {len(self.custom_data)} 个训练样本到 {file_path}")
        except Exception as e:
            self.debug_label.config(text=f"保存失败: {str(e)}")
            print(f"保存训练集异常: {e}")

    def load_training_data(self):
        """Load samples from a CSV written by ``save_all_training_data``.

        Rows with fewer than 65 columns are skipped. The current samples are
        only replaced when the whole file parses successfully.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("CSV文件", "*.csv"), ("所有文件", "*.*")],
            title="加载训练集",
        )
        if not file_path:
            return
        try:
            loaded = []
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                next(reader, None)  # header row
                for row in reader:
                    if len(row) < 65:
                        continue
                    loaded.append(([int(pixel) for pixel in row[:64]], int(row[64])))
            self.custom_data = loaded
            self.debug_label.config(text=f"已从 {file_path} 加载 {len(self.custom_data)} 个训练样本")
        except Exception as e:
            self.debug_label.config(text=f"加载失败: {str(e)}")
            print(f"加载训练集异常: {e}")

    def run(self):
        """Enter the Tk main loop."""
        self.root.mainloop()


if __name__ == "__main__":
    digits = load_digits()
    root = tk.Tk()
    app = HandwritingBoard(root, ModelFactory, digits)
    app.run()

帮我优化代码
06-23
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值