pytorch学习笔记(2):在MNIST上实现一个CNN

本文介绍了如何在MNIST数据集上使用PyTorch实现一个卷积神经网络(CNN)。内容包括:读取PyTorch自带的MNIST数据集并分割训练集和测试集,定义CNN的网络结构,以及完成训练过程。通过Sequential简化网络结构,使用nn.Conv2d、nn.ReLU和nn.MaxPool2d等构建CNN,并在训练过程中每50步输出训练误差和测试精度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

参考文档:https://mp.weixin.qq.com/s/1TtPWYqVkj2Gaa-3QrEG1A

这篇文章是在一个大家经常见到的数据集 MNIST 上实现一个简单的 CNN。我们会基于上一篇文章中的分类器,来讨论实现一个 CNN,需要在之前的内容上做出哪些升级。

这篇笔记的内容包含三个部分:读取 pytorch 自带的数据集并分割;实现一个 CNN 的网络结构;完成训练。这三个部分合起来完成了一个简单的浅层卷积神经网络,在 MNIST 上进行训练和测试。

1、读取自带数据集并分割

train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

以这个为例,我们就可以知道如何从 pytorch 中直接使用自带的数据集。而且 pytorch 包含了很多常用的数据集(我们比较熟悉的如 MNIST,cifar家族,ImageNet 等),所以熟悉一下如何使用自带的数据集也是非常有帮助的。

导入自带的数据集,按照上面代码的读取路径为:torchvision.datasets,在这一步后面跟上选择自己想要的数据集,这里我们选择了 MNIST。

这里的几个参数也很简单,root 是读取数据的路径,如果没有就去下载;train 为 True 表示读取的是训练集,反之就是测试集;transform 表示了将图片转换为 tensor;最后的 download 参数来设置是否从网上下载数据集。

import torch.utils.data as Data
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1)/255.
test_y = test_data.targets

利用 torch 中 data 模块的 DataLoader 函数,我们可以设置一个数据加载器。设置的参数有三个,dataset 是前面读入的数据集;batch_size 就是表面意思,每批处理的数据的数量;shuffle 表示是否打乱数据。

下面的 test_data 和前面读取 train_data 是相同的方法,只不过将参数 train 调整为 False。接着 test_x 通过 unsqueeze 为数据加了一个维度,然后对所有数据除

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值