前言
相信对于每一个刚刚上手深度学习的孩子来说,利用mnist数据集来训练一个CNN是再好不过的学习demo了。
本文使用 pytorch 来动手搭建一个卷积神经网络来训练和预测手写数字。通过本文,你将了解到pytorch的一些功能:
- 高效加载数据集;
- 简单灵活地设计神经网络;
- 了解对训练和泛化有帮助的网络结构tricks(如batchnorm,dropout)
- 学习优化器(一般用adam);
- 神经网络的损失函数(一般用交叉熵);
- 学习率的动态调节(学习率的动态变化);
- pytorch 训练过程(尤其是批量进行的训练方式mini-batch)
- pytorch 预测的过程
接下来就开始啦,每一部分的代码我尽量搭配详细的注释,让你快速理解,轻松上手pytorch!
引入库函数
import torch # pytorch 最基本模块
import torch.nn as nn # pytorch中最重要的模块,封装了神经网络相关的函数
import torch.nn.functional as F # 提供了一些常用的函数,如softmax
import torch.optim as optim # 优化模块,封装了求解模型的一些优化器,如Adam SGD
from torch.optim import lr_scheduler # 学习率调整器,在训练过程中合理变动学习率
from torchvision import transforms #pytorch 视觉库中提供了一些数据变换的接口
from torchvision import datasets #pytorch 视觉库提供了加载数据集的接口
预设超参数
# 预设网络超参数 (所谓超参数就是可以人为设定的参数
BATCH_SIZE= 64 # 由于使用批量训练的方法,需要定义每批的训练的样本数目
EPOCHS= 2 # 总共训练迭代的次数
# 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 0.1 # 设定初始的学习率
加载数据集
像MNIST这么知名的数据集,pytorch居然内置了对应的加载接口,真的优秀!不过第一次使用我们会下载数据集到一个文件夹中,以后就可以直接读取该文件夹内部的数据了。这里我们使用dataloader迭代器来加载数据集,题外话:迭代器的作用可以减少内存的占用。
# 加载训练集
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,), std=(0.5,)) # 数据规范化到正态分布
])),
batch_size=BATCH_SIZE, shuffle=True) # 指明批量大小,打乱,这是处于后续训练的需要。
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])<