《手写数字识别实战:从逻辑回归到CNN的进阶之路(完整代码解析)》
## 📚 前言
手写数字识别(MNIST)是机器学习领域的"Hello World"。本文将从最简单的逻辑回归模型出发,逐步实现ANN和CNN,通过**完整代码+可视化训练过程**,带你直观感受不同模型的性能差异。文末附完整代码和数据集!
---
## 🛠️ 环境准备
```python
# 基础库
import torch
import torch.nn as nn
from torchvision import datasets, transforms
# 可视化
import matplotlib.pyplot as plt
%matplotlib inline
# 设置GPU加速
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
📈 模型效果对比
| 模型 |
最高准确率 |
训练时间(15 epochs) |
| 逻辑回归 |
85.85% |
~2分钟 |
| ANN |
96.26% |
~2分钟 |
| CNN |
98.26% |
~2分钟 |
一、逻辑回归:基础入门
1. 核心代码
class LogisticRegression(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
2. 训练过程
Epoch: 15/15 | Iteration: 9500 | Loss: 0.5245 | Accuracy: 85.85%
二、增强版ANN(三隐藏层)