参考文档:https://mp.weixin.qq.com/s/1TtPWYqVkj2Gaa-3QrEG1A
这篇文章是在一个大家经常见到的数据集 MNIST 上实现一个简单的 CNN。我们会基于上一篇文章中的分类器,来讨论实现一个 CNN,需要在之前的内容上做出哪些升级。
这篇笔记的内容包含三个部分:读取 pytorch 自带的数据集并分割;实现一个 CNN 的网络结构;完成训练。这三个部分合起来完成了一个简单的浅层卷积神经网络,在 MNIST 上进行训练和测试。
1、读取自带数据集并分割
train_data = torchvision.datasets.MNIST(
root='./mnist',
train=True,
transform=torchvision.transforms.ToTensor(),
download=True
)
以这个为例,我们就可以知道如何从 pytorch 中直接使用自带的数据集。而且 pytorch 包含了很多常用的数据集(我们比较熟悉的如 MNIST,cifar家族,ImageNet 等),所以熟悉一下如何使用自带的数据集也是非常有帮助的。
导入自带的数据集,按照上面代码的读取路径为:torchvision.datasets,在这一步后面跟上选择自己想要的数据集,这里我们选择了 MNIST。
这里的几个参数也很简单,root 是读取数据的路径,如果没有就去下载;train 为 True 表示读取的是训练集,反之就是测试集;transform 表示了将图片转换为 tensor;最后的 download 参数来设置是否从网上下载数据集。
import torch.utils.data as Data
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(
root='./mnist',
train=False,
transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1)/255.
test_y = test_data.targets
利用 torch 中 data 模块的 DataLoader 函数,我们可以设置一个数据加载器。设置的参数有三个,dataset 是前面读入的数据集;batch_size 就是表面意思,每批处理的数据的数量;shuffle 表示是否打乱数据。
下面的 test_data 和前面读取 train_data 是相同的方法,只不过将参数 train 调整为 False。接着 test_x 通过 unsqueeze 为数据加了一个维度,然后对所有数据除