第P1周:Pytorch实现mnist手写数字识别

基于Pytorch构建CNN网络并训练


FROM


我的环境

  • 语言环境:Python 3.8.10
  • 开发工具:Jupyter Lab
  • 深度学习环境:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

1. 准备知识

1.1 检查环境

import torch  # 导入PyTorch库,用于构建深度学习模型
import torch.nn as nn  # 导入torch.nn模块,包含构建神经网络所需的类和函数
import matplotlib.pyplot as plt  # 导入matplotlib.pyplot模块,用于数据可视化
import torchvision  # 导入torchvision库,包含处理图像和视频的工具和预训练模型

# 设置硬件设备,如果有GPU则使用,没有则使用cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查系统是否有可用的GPU,如果有则使用GPU,否则使用CPU
device  # 打印当前设备,以确认是使用GPU还是CPU

输出:
输出

1.2 数据导入

torchvision.datasets.MNIST
torchvision.datasets是Pytorch自带的一个数据库,可以通过代码在线下载数据。

这行代码是用于加载MNIST数据集的。MNIST是一个大型的手写数字数据库,常用于训练各种图像处理系统。代码中的参数含义如下:

  • root: 数据集的本地存储路径。
  • train: 布尔值,如果为True,则加载训练集;如果为False,则加载测试集。
  • transform: 一个可选的transform,用于对图像进行预处理操作,例如缩放、裁剪等。这里设置为None,表示不进行任何预处理。
  • target_transform: 一个可选的transform,用于对标签进行预处理操作。这里设置为None,表示不进行任何预处理。
  • download: 布尔值,如果为True,则从网上下载数据集并存储到root指定的路径;如果为False,则假定数据集已经存在于该路径下。
torchvision.datasets.MNIST(  # 调用torchvision中的MNIST数据集类
    root,  # 指定数据集的存储路径
    train=True,  # 指定加载训练集,如果设置为False,则加载测试集
    transform=None,  # 指定如何对图像进行预处理,例如缩放、裁剪等,这里设置为None,表示不进行预处理
    target_transform=None,  # 指定如何对标签进行预处理,这里设置为None,表示不进行预处理
    download=False  # 指定是否需要下载数据集,如果数据集已经存在于root路径下,则设置为False
)
train_ds = torchvision.datasets.MNIST(  # 创建MNIST训练集数据对象
    'data',  # 指定数据集的存储路径
    train=True,  # 指定加载训练集
    transform=torchvision.transforms.ToTensor(),  # 指定预处理操作,将图像转换为Tensor
    download=True  # 如果数据集不存在,则下载数据集
)

test_ds = torchvision.datasets.MNIST(  # 创建MNIST测试集数据对象
    'data',  # 指定数据集的存储路径
    train=False,  # 指定加载测试集
    transform=torchvision.transforms.ToTensor(),  # 指定预处理操作,将图像转换为Tensor
    download=True  # 如果数据集不存在,则下载数据集
)

输出:
[图片]

torch.utils.data.DataLoader
torch.utils.data.DataLoader是Pytorch自带的一个数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。

torch.utils.data.DataLoader(
    dataset,  # 要加载的数据集对象
    batch_size=1,  # 每个批次的样本数,默认为1
    shuffle=None,  # 是否在每个epoch开始时打乱数据,默认为None,根据sampler和batch_sampler的设置决定
    sampler=None,  # 定义从数据集中抽取样本的策略,如果设置了sampler,则shuffle应为False
    batch_sampler=None,  # 与sampler类似,但返回的是样本索引的批次
    num_workers=0,  # 加载数据时使用的进程数,0表示单进程,正整数表示多进程
    collate_fn=None,  # 决定如何将多个样本合并为一个批次的函数,默认为'default_collate'
    pin_memory=False,  # 如果设置为True,则在将数据从CPU传输到GPU之前,先将数据放到CUDA固定内存中
    drop_last=False,  # 如果为True,则在数据集大小不能被batch_size整除时,丢弃最后一个不完整的批次
    timeout=0,  # 如果大于0,则设置为非零超时时间,用于数据加载
    worker_init_fn=None,  # 每个工作进程启动时调用的函数
    multiprocessing_context=None,  # 用于数据加载进程的多进程上下文
    generator=None,  # 当使用随机采样时使用的随机数生成器
    *,  # 确保以下参数必须作为关键字参数传递
    prefetch_factor=2,  # 数据加载器的预取因子,用于控制预取的数据量
    persistent_workers=False,  # 如果为True,则在所有数据加载完成后,工作进程不会退出
    pin_memory_device='',  # 指定用于pin_memory的设备,默认为CPU
)
  • dataset: 你想要加载
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值