pytorch 实战 | 动手设计CNN+MNIST手写体数字识别

本文是使用PyTorch实现CNN模型识别MNIST手写数字的教程,涵盖了数据加载、网络设计、训练与预测。通过这个教程,你将学习到PyTorch的基本操作,包括数据集处理、卷积神经网络结构、训练流程、预测过程和技巧如BatchNorm和Dropout。虽然只训练了两个epoch,但模型已经展现出良好的准确度。继续训练和应用数据增强将有助于提升模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

相信对于每一个刚刚上手深度学习的孩子来说,利用mnist数据集来训练一个CNN是再好不过的学习demo了。

本文使用 pytorch 来动手搭建一个卷积神经网络来训练和预测手写数字。通过本文,你将了解到pytorch的一些功能:

  1. 高效加载数据集;
  2. 简单灵活地设计神经网络;
  3. 了解对训练和泛化有帮助的网络结构tricks(如batchnorm,dropout)
  4. 学习优化器(一般用adam);
  5. 神经网络的损失函数(一般用交叉熵);
  6. 学习率的动态调节(学习率的动态变化);
  7. pytorch 训练过程(尤其是批量进行的训练方式mini-batch)
  8. pytorch 预测的过程

接下来就开始啦,每一部分的代码我尽量搭配详细的注释,让你快速理解,轻松上手pytorch!

引入库函数

import torch                     # pytorch 最基本模块
import torch.nn as nn            # pytorch中最重要的模块,封装了神经网络相关的函数
import torch.nn.functional as F  # 提供了一些常用的函数,如softmax
import torch.optim as optim      # 优化模块,封装了求解模型的一些优化器,如Adam SGD
from torch.optim import lr_scheduler # 学习率调整器,在训练过程中合理变动学习率
from torchvision import transforms  #pytorch 视觉库中提供了一些数据变换的接口
from torchvision import datasets  #pytorch 视觉库提供了加载数据集的接口

预设超参数

# 预设网络超参数 (所谓超参数就是可以人为设定的参数

BATCH_SIZE= 64 # 由于使用批量训练的方法,需要定义每批的训练的样本数目

EPOCHS= 2      # 总共训练迭代的次数

# 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

learning_rate = 0.1  # 设定初始的学习率

加载数据集

像MNIST这么知名的数据集,pytorch居然内置了对应的加载接口,真的优秀!不过第一次使用我们会下载数据集到一个文件夹中,以后就可以直接读取该文件夹内部的数据了。这里我们使用dataloader迭代器来加载数据集,题外话:迭代器的作用可以减少内存的占用。

# 加载训练集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, 
                    transform=transforms.Compose([
                        transforms.ToTensor(), 
                        transforms.Normalize(mean=(0.5,), std=(0.5,)) # 数据规范化到正态分布
                    ])),
    batch_size=BATCH_SIZE, shuffle=True) # 指明批量大小,打乱,这是处于后续训练的需要。

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,)) 
                    ])<
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值