FROM
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
我的环境
- 语言环境:Python 3.8.10
- 开发工具:Jupyter Lab
- 深度学习环境:
- torch==1.12.1+cu113
- torchvision==0.13.1+cu113
1. 准备知识
1.1 检查环境
import torch # 导入PyTorch库,用于构建深度学习模型
import torch.nn as nn # 导入torch.nn模块,包含构建神经网络所需的类和函数
import matplotlib.pyplot as plt # 导入matplotlib.pyplot模块,用于数据可视化
import torchvision # 导入torchvision库,包含处理图像和视频的工具和预训练模型
# 设置硬件设备,如果有GPU则使用,没有则使用cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 检查系统是否有可用的GPU,如果有则使用GPU,否则使用CPU
device # 打印当前设备,以确认是使用GPU还是CPU
输出:

1.2 数据导入
torchvision.datasets.MNIST
torchvision.datasets是Pytorch自带的一个数据库,可以通过代码在线下载数据。
这行代码是用于加载MNIST数据集的。MNIST是一个大型的手写数字数据库,常用于训练各种图像处理系统。代码中的参数含义如下:
root: 数据集的本地存储路径。train: 布尔值,如果为True,则加载训练集;如果为False,则加载测试集。transform: 一个可选的transform,用于对图像进行预处理操作,例如缩放、裁剪等。这里设置为None,表示不进行任何预处理。target_transform: 一个可选的transform,用于对标签进行预处理操作。这里设置为None,表示不进行任何预处理。download: 布尔值,如果为True,则从网上下载数据集并存储到root指定的路径;如果为False,则假定数据集已经存在于该路径下。
torchvision.datasets.MNIST( # 调用torchvision中的MNIST数据集类
root, # 指定数据集的存储路径
train=True, # 指定加载训练集,如果设置为False,则加载测试集
transform=None, # 指定如何对图像进行预处理,例如缩放、裁剪等,这里设置为None,表示不进行预处理
target_transform=None, # 指定如何对标签进行预处理,这里设置为None,表示不进行预处理
download=False # 指定是否需要下载数据集,如果数据集已经存在于root路径下,则设置为False
)
train_ds = torchvision.datasets.MNIST( # 创建MNIST训练集数据对象
'data', # 指定数据集的存储路径
train=True, # 指定加载训练集
transform=torchvision.transforms.ToTensor(), # 指定预处理操作,将图像转换为Tensor
download=True # 如果数据集不存在,则下载数据集
)
test_ds = torchvision.datasets.MNIST( # 创建MNIST测试集数据对象
'data', # 指定数据集的存储路径
train=False, # 指定加载测试集
transform=torchvision.transforms.ToTensor(), # 指定预处理操作,将图像转换为Tensor
download=True # 如果数据集不存在,则下载数据集
)
输出:
![[图片]](https://i-blog.csdnimg.cn/direct/371de800046e497e881d3f0ee6e7e19a.png)
torch.utils.data.DataLoader
torch.utils.data.DataLoader是Pytorch自带的一个数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
torch.utils.data.DataLoader(
dataset, # 要加载的数据集对象
batch_size=1, # 每个批次的样本数,默认为1
shuffle=None, # 是否在每个epoch开始时打乱数据,默认为None,根据sampler和batch_sampler的设置决定
sampler=None, # 定义从数据集中抽取样本的策略,如果设置了sampler,则shuffle应为False
batch_sampler=None, # 与sampler类似,但返回的是样本索引的批次
num_workers=0, # 加载数据时使用的进程数,0表示单进程,正整数表示多进程
collate_fn=None, # 决定如何将多个样本合并为一个批次的函数,默认为'default_collate'
pin_memory=False, # 如果设置为True,则在将数据从CPU传输到GPU之前,先将数据放到CUDA固定内存中
drop_last=False, # 如果为True,则在数据集大小不能被batch_size整除时,丢弃最后一个不完整的批次
timeout=0, # 如果大于0,则设置为非零超时时间,用于数据加载
worker_init_fn=None, # 每个工作进程启动时调用的函数
multiprocessing_context=None, # 用于数据加载进程的多进程上下文
generator=None, # 当使用随机采样时使用的随机数生成器
*, # 确保以下参数必须作为关键字参数传递
prefetch_factor=2, # 数据加载器的预取因子,用于控制预取的数据量
persistent_workers=False, # 如果为True,则在所有数据加载完成后,工作进程不会退出
pin_memory_device='', # 指定用于pin_memory的设备,默认为CPU
)
dataset: 你想要加载

最低0.47元/天 解锁文章
1万+





