系列文章目录
- pytorch MNIST数据集无法正常加载的解决办法( HTTP Error 503: Service Unavailable)
- Python 手写数字识别的实现(pytorch框架) 超详细版本
- pytorch 手写数字识别 新网络设计和学习率探索
前言
这是中国科学院大学深度学习的课程作业
本文详细介绍了如何构建LeNet-5神经网络用于手写数字识别。
文中大量的代码解释包含在代码行后的注释中,请注意查看。
在pytorch 手写数字识别 新网络设计和学习率探索中我探索了一个新结构的CNN网络用于手写数字识别,可参看。
下面的代码在谷歌云盘的colab上运行,也可以在jupyter notebook上运行
文本参考了用PyTorch实现MNIST手写数字识别(非常详细)中的部分内容。
分步骤解释
- 首先导入需要的包
import torch
import torchvision
import torchvision.transforms as transforms
- 设置超参数,每个参数解释见注释
n_epochs = 5 # 模型训练5轮
log_interval = 30 #控制打印频率的,设n = 30*batch_size,即n张图后打印一次进度
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 根据设备是否支持GPU来选择硬件
size = 32 # 对输入图片进行处理,拉伸为32*32的图片,这是为了复刻手写数字识别的神经网络,其输入为32*32的灰度图像
learn_rate = 0.03 # 学习率
momentum = 0.1 # 动量
- 加载数据集(见blog:pytorch集成的数据集无法访问,采用了指定url方法)
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz
from torchvision.datasets import MNIST
transform = transforms.Compose(
[ transforms.Resize(size), transforms.ToTensor(),
transforms.Normalize((0.5), (0.5))]) # 正则化处理,相当于z-score
trainset = MNIST(root = './', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = MNIST(root = './', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=True, num_workers=2)
# classes = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '0')
这一步需要说明的是,对加载来的图像进行了增强处理(transforms.Compose()),主要是为了防止过拟合、
详细解释推荐博客:https://medium.com/@CinnamonAITaiwan/cnn%E5%85%A5%E9%96%80-%E5%9C%96%E5%83%8F%E5%A2%9E%E5%BC%B7-fa654d36dafc
后面将通过进一步实验探索图像增强对结果的影响
- 打印测试集的标签和tensor大小
examples = enumerate(testloader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
这一步目的是查看数据是否符合我们的要求
结果:
tensor([3, 0, 9, 7, 6, 1, 4, 9, 3, 1, 9, 6, 1, 2, 4, 2, 6, 0, 1, 4, 7, 3, 7, 7,
7, 1, 6, 6, 9, 0, 8, 6, 9, 2, 8, 3, 7, 0, 3, 5, 1, 1, 5, 6, 1, 6, 8, 6,
5, 2, 5, 1, 4, 8, 8, 1, 4, 2, 1, 8, 6, 2, 4, 9, 3, 0, 7, 5, 2, 2, 4, 8,
7, 4, 9, 2, 2, 7, 7, 8, 1, 7, 4, 8, 7, 8, 9, 1, 5, 5, 4, 8, 0, 5, 4, 9,
5, 1, 1, 3, 9