DataLoader (1)

本文介绍了PyTorch中的DataLoader和Dataset的概念及其作用。DataLoader是用于处理模型输入数据的工具,结合了Dataset和采样器,支持单线程或多线程迭代。关键参数包括epoch、iteration、batch_size、shuffle、num_workers和drop_last。文中通过示例展示了如何创建和使用DataLoader,以及如何利用Tensorboard进行可视化,其中示例数据集为CIFAR10。
部署运行你感兴趣的模型镜像

DataLoader(1)

torch.utils.data.Dataset是代表这一数据的抽象类(也就是基类)。我们可以通过继承和重写这个抽象类实现自己的数据类,只需要定义__len__和__getitem__这个两个函数。
DataLoader是Pytorch中用来处理模型输入数据的一个工具类。组合了数据集(dataset) + 采样器(sampler),并在数据集上提供单线程或多线程(num_workers )的可迭代对象。在DataLoader中有多个参数,这些参数中重要的几个参数的含义说明如下:

 epoch:所有的训练样本输入到模型中称为一个epoch; 
 iteration:一批样本输入到模型中,成为一个Iteration;
 batchszie:一批样本大小,决定一个epoch有多少个Iteration;
 迭代次数(iteration)=样本总数(epoch)/批尺寸(batchszie)
 dataset (Dataset) – 决定数据从哪读取或者从何读取;
 batch_size (python:int, optional) – 批尺寸(每次训练样本个数,默认为1)
 shuffle (bool, optional) –每一个 epoch是否为乱序 (default: False)num_workers (python:int, optional) – 是否多进程读取数据(默认为0);
 drop_last (bool, optional) – 当样本数不能被batchsize整除时,最后一批数据是否舍弃(default: False 不舍去)
 pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false) 

DataLoader 与 DataSet 详情可参考:

https://blog.youkuaiyun.com/He3he3he/article/details/105441083

一、DataLoader使用

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备测试的数据集
test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

test_loader = DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

# 测试的数据集中的第一张图片及target
img,target = test_data[0]
print(img.shape)  #  torch.Size([3, 32, 32])  3 通道彩色图片  长为32  宽为32
print(target)     #  target 为 3

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs,targets = data
    print(imgs.shape)   #   torch.Size([4, 3, 32, 32])   43通道  32*32 的图片
    print(targets)      #   tensor([3, 4, 6, 5])         4张图片的target 分别是3 4 6 5
    writer.add_images("test_data",imgs,step)
    step += 1

writer.close()

通过Tensorboard观察得到一批样本的图片数 是 4 张

image-20221216225922214

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值