import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import tkinter as tk
from tkinter import Button, messagebox
from PIL import Image, ImageDraw, ImageOps
import os
# ====================== Training section ====================== #
# Preprocessing: PIL image -> float tensor in [0, 1], then standardize with
# the MNIST mean and standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std
])
# Training split (downloaded into ./data if not already present).
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)  # larger batch size
# Test split. download=True makes this line self-sufficient instead of relying
# on the training split above having fetched the files first; if the files
# already exist nothing is downloaded again.
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1000)
# Improved CNN model definition
class Improved_MNIST_CNN(nn.Module):
    """CNN classifier for single-channel 28x28 digit images.

    Three conv blocks, each Conv3x3 (padding 1) -> BatchNorm -> ReLU ->
    MaxPool(2), shrink a 28x28 input to 14 -> 7 -> 3 (floor), so the flattened
    feature size is 128 * 3 * 3. A three-layer fully connected head with
    dropout then produces 10 unnormalized logits (no softmax; pair with
    nn.CrossEntropyLoss).

    Note: ReLU activations between the fully connected layers were added. They
    have no parameters, so older checkpoints still load, but weights trained
    without them should be retrained.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 128 channels on a 3x3 map; this size assumes a 28x28 input.
        self.fc1 = nn.Linear(128 * 3 * 3, 512)
        self.dropout1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) class logits."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = torch.flatten(x, 1)
        # Without a nonlinearity here, fc1 -> fc2 -> fc3 composes into a single
        # linear map and the extra layers add no representational power.
        x = self.dropout1(torch.relu(self.fc1(x)))
        x = self.dropout2(torch.relu(self.fc2(x)))
        return self.fc3(x)
# Module-level model, loss function, optimizer and learning-rate schedule
# (these are the objects run_training() trains).
model = Improved_MNIST_CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)
# Learning-rate decay: multiply the rate by 0.7 every 5 scheduler steps
# (train_model steps the scheduler once per epoch).
scheduler = lr_scheduler.StepLR(optimizer, gamma=0.7, step_size=5)
# Training loop
def train_model(model, train_loader, test_loader, optimizer, scheduler, epochs=15,
                loss_fn=None, save_path='mnist_model_best.pth'):
    """Train ``model`` and checkpoint the weights with the best test accuracy.

    Each epoch runs one pass over ``train_loader``, steps ``scheduler``,
    evaluates with ``test_model`` on ``test_loader``, and saves the state_dict
    to ``save_path`` whenever accuracy improves.

    Note: the "best" checkpoint is chosen on the test split itself (there is
    no separate validation split), so the reported best accuracy is optimistic.

    Args:
        model: the network to train (modified in place).
        train_loader: iterable of ``(data, target)`` training batches.
        test_loader: iterable of ``(data, target)`` evaluation batches.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        epochs: number of passes over ``train_loader``.
        loss_fn: loss callable ``(output, target) -> loss``; defaults to the
            module-level ``criterion``.
        save_path: file the best state_dict is written to.

    Returns:
        The best test accuracy in percent (0.0 if no epoch improved on 0).
    """
    if loss_fn is None:
        loss_fn = criterion
    best_accuracy = 0.0
    num_batches = len(train_loader)
    for epoch in range(epochs):
        # Enter training mode explicitly every epoch instead of relying on the
        # evaluation helper to switch dropout/batch-norm back after testing.
        model.train()
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')
        # Update the learning rate
        scheduler.step()
        # Evaluate on the test split after every epoch
        accuracy = test_model(model, test_loader)
        # Guard against an empty loader (would otherwise divide by zero).
        avg_loss = running_loss / max(num_batches, 1)
        print(f'Epoch {epoch+1} completed, Avg Loss: {avg_loss:.6f}, Test Accuracy: {accuracy:.2f}%')
        # Keep the best model seen so far
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model with accuracy: {best_accuracy:.2f}%")
    return best_accuracy
# Evaluation helper
def test_model(model, test_loader):
    """Return the classification accuracy (percent) of ``model`` on ``test_loader``.

    The model is evaluated in eval mode without gradient tracking, and its
    previous train/eval mode is restored afterwards. That means a training loop
    gets back a model in training mode, and an inference model stays in eval
    mode.

    Args:
        model: module returning per-class scores of shape (batch, num_classes).
        test_loader: iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in percent; 0.0 when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, target in test_loader:
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        return 0.0
    return 100. * correct / total


# The name starts with "test_" but this is not a pytest test; stop pytest from
# collecting it when this module is imported into a test file.
test_model.__test__ = False
# Train the module-level model
def run_training():
    """Train the module-level ``model`` for 15 epochs and save its final weights.

    Uses the module-level ``model``, ``train_loader``, ``test_loader``,
    ``optimizer`` and ``scheduler``. ``train_model`` writes the best-epoch
    checkpoint; this function additionally saves the last-epoch weights to
    ``mnist_model_final.pth``.

    Returns:
        The best test accuracy (percent) reported by ``train_model``.
    """
    print("开始训练模型...")
    top_acc = train_model(
        model,
        train_loader,
        test_loader,
        optimizer,
        scheduler,
        epochs=15,
    )
    print(f"训练完成! 最佳准确率: {top_acc:.2f}%")
    # Last-epoch weights; these may differ from the best-epoch checkpoint.
    final_weights = model.state_dict()
    torch.save(final_weights, 'mnist_model_final.pth')
    print("模型已保存为: mnist_model_final.pth")
    return top_acc
# ====================== Recognition section ====================== #
# Load a trained model from disk
def load_model(model_path='mnist_model_best.pth'):
    """Build an ``Improved_MNIST_CNN`` and load weights from ``model_path``.

    Weights are mapped onto the CPU and the model is switched to eval mode.

    Returns:
        The loaded model, or None if the file does not exist or loading raised
        (the reason is printed in both cases).
    """
    net = Improved_MNIST_CNN()
    try:
        if not os.path.exists(model_path):
            print(f"警告: 找不到模型文件 '{model_path}'")
            return None
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # this program produced (consider weights_only=True on torch >= 1.13).
        state = torch.load(model_path, map_location=torch.device('cpu'))
        net.load_state_dict(state)
        net.eval()
        print(f"成功加载模型: {model_path}")
        return net
    except Exception as e:
        print(f"加载模型时出错: {e}")
        return None
# Handwritten digit recognition GUI
class DigitRecognizer:
    """Tkinter app: draw a digit with the mouse and classify it with a model.

    Every stroke is rendered twice: on the visible Tk canvas, and on an
    off-screen PIL grayscale image (white background, black ink). The PIL image
    is what gets preprocessed and fed to the model.
    """

    def __init__(self, model):
        """Create the window and widgets.

        Args:
            model: trained classifier, or None (recognition then shows an error).
        """
        self.model = model
        # Same normalization as training. The Resize is a safeguard only:
        # preprocess() already returns a 28x28 image.
        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        # Main window
        self.root = tk.Tk()
        self.root.title("MNIST手写数字识别")
        self.root.geometry("400x500")
        # Title
        self.title_label = tk.Label(self.root, text="手写数字识别", font=("Arial", 16))
        self.title_label.pack(pady=10)
        # Drawing canvas
        self.canvas_width = 280
        self.canvas_height = 280
        self.canvas = tk.Canvas(
            self.root,
            width=self.canvas_width,
            height=self.canvas_height,
            bg="white",
            cursor="cross"
        )
        self.canvas.pack(pady=10)
        # Pen radius in canvas pixels
        self.brush_radius = 10
        # Previous pointer position within the current stroke; None between strokes.
        self._last_point = None
        # Mouse bindings: press starts a stroke (so a plain click leaves a dot),
        # drag extends it, release ends it.
        self.canvas.bind("<ButtonPress-1>", self._start_stroke)
        self.canvas.bind("<B1-Motion>", self.draw)
        self.canvas.bind("<ButtonRelease-1>", self._end_stroke)
        # Off-screen PIL image that mirrors the canvas (255 = white background)
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_img = ImageDraw.Draw(self.image)
        # Button row
        button_frame = tk.Frame(self.root)
        button_frame.pack(pady=10)
        # Recognize button
        self.recognize_btn = Button(
            button_frame,
            text="识别",
            command=self.recognize,
            width=10,
            height=2,
            bg="#4CAF50",
            fg="white",
            font=("Arial", 12)
        )
        self.recognize_btn.pack(side=tk.LEFT, padx=10)
        # Clear button
        self.clear_btn = Button(
            button_frame,
            text="清除",
            command=self.reset,
            width=10,
            height=2,
            bg="#F44336",
            fg="white",
            font=("Arial", 12)
        )
        self.clear_btn.pack(side=tk.LEFT, padx=10)
        # Result label
        self.result_label = tk.Label(
            self.root,
            text="结果: 请书写数字并点击'识别'",
            font=("Arial", 14),
            pady=10
        )
        self.result_label.pack()
        # Status bar
        self.status_var = tk.StringVar()
        self.status_var.set("就绪")
        self.status_bar = tk.Label(
            self.root,
            textvariable=self.status_var,
            bd=1,
            relief=tk.SUNKEN,
            anchor=tk.W
        )
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)
        print("请在画布上书写数字,然后点击'识别'按钮...")

    def reset(self):
        """Clear both the visible canvas and the off-screen image."""
        self.canvas.delete("all")
        self.image = Image.new("L", (self.canvas_width, self.canvas_height), 255)
        self.draw_img = ImageDraw.Draw(self.image)
        self._last_point = None
        self.result_label.config(text="结果: 请书写数字并点击'识别'")
        self.status_var.set("画布已清除")

    def _start_stroke(self, event):
        """Begin a new stroke at the press position (a click alone draws a dot)."""
        self._last_point = None
        self.draw(event)

    def _end_stroke(self, _event):
        """End the current stroke so the next one is not joined to it."""
        self._last_point = None

    def draw(self, event):
        """Extend the current stroke to the pointer position in ``event``.

        Consecutive motion events are joined with a thick line, so fast mouse
        movement produces a continuous stroke instead of separated dots.
        """
        x, y = event.x, event.y
        r = self.brush_radius
        if self._last_point is not None:
            x0, y0 = self._last_point
            self.canvas.create_line(x0, y0, x, y, width=2 * r, fill="black",
                                    capstyle=tk.ROUND)
            self.draw_img.line([x0, y0, x, y], fill=0, width=2 * r)
        # Round dot at the current point; also rounds the joints of the PIL line.
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill="black", outline="black")
        self.draw_img.ellipse([x - r, y - r, x + r, y + r], fill=0)
        self._last_point = (x, y)

    def preprocess(self):
        """Turn the drawing into a 28x28 image in the format the model was trained on.

        The steps are:
        1. Invert so the digit is white on a black background.
        2. Crop to the ink's bounding box.
        3. Scale so the longer side is 20 px, keeping the aspect ratio.
        4. Paste centered onto a black 28x28 image.

        Returns:
            A 28x28 PIL "L" image, or None if nothing has been drawn.
        """
        # Invert: black-on-white drawing -> white-on-black (MNIST style)
        inverted_img = ImageOps.invert(self.image)
        # Bounding box of the non-zero (ink) pixels
        bbox = inverted_img.getbbox()
        if not bbox:
            return None
        cropped = inverted_img.crop(bbox)
        width, height = cropped.size
        scale = 20.0 / max(width, height)
        # round() keeps the long side at exactly 20 px (int() can truncate
        # 19.999... to 19); max(1, ...) prevents a zero-sized side for very thin
        # strokes, which would make resize() raise.
        new_width = max(1, round(width * scale))
        new_height = max(1, round(height * scale))
        resized = cropped.resize((new_width, new_height), Image.LANCZOS)
        # 28x28 black background, digit centered on it
        final_img = Image.new("L", (28, 28), 0)
        x_offset = (28 - new_width) // 2
        y_offset = (28 - new_height) // 2
        final_img.paste(resized, (x_offset, y_offset))
        return final_img

    def recognize(self):
        """Classify the current drawing and show the digit and its softmax confidence."""
        if self.model is None:
            messagebox.showerror("错误", "模型未加载成功,请先训练模型")
            return
        processed_img = self.preprocess()
        if processed_img is None:
            self.status_var.set("错误: 未检测到书写内容")
            messagebox.showwarning("警告", "未检测到书写内容,请在画布上书写数字")
            return
        # (1, 1, 28, 28): add the batch dimension
        tensor = self.transform(processed_img).unsqueeze(0)
        # Inference mode for dropout/batch-norm, even if the caller passed a
        # model that is still in training mode.
        self.model.eval()
        with torch.no_grad():
            output = self.model(tensor)
        probabilities = torch.softmax(output[0], dim=0)
        digit = int(probabilities.argmax())
        confidence = probabilities[digit].item() * 100
        self.result_label.config(text=f"识别结果: {digit} (置信度: {confidence:.1f}%)")
        self.status_var.set(f"识别完成: {digit} (置信度: {confidence:.1f}%)")
        # Debugging aid: processed_img.show() displays the 28x28 model input.
# Entry point
def main():
    """Load the best checkpoint (training one first if missing) and run the GUI.

    If loading fails, load_model has already printed the reason and the GUI is
    not started.
    """
    model_path = 'mnist_model_best.pth'
    if os.path.exists(model_path):
        model = load_model(model_path)
    else:
        print("未找到预训练模型,开始训练新模型...")
        run_training()
        model = load_model(model_path)
    if model is None:
        return
    recognizer = DigitRecognizer(model)
    recognizer.root.mainloop()


if __name__ == "__main__":
    main()