06-DataLoader的使用

本文详细解释了PyTorch中DataLoader的使用,包括其参数含义(如dataset、batch_size、shuffle等),并通过CIFAR10数据集实例展示了如何创建DataLoader以及不同参数设置对数据加载的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

DataLoader的使用

torch.utils.data.DataLoader

形象理解:

  • dataset:一副扑克

  • dataloader:抽牌方式

CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, ***, prefetch_factor=None, persistent_workers=False, pin_memory_device='')[SOURCE]

常用参数的通俗解释

  • dataset: 自定义数据集

  • batchsize: 每次抽牌抽几张,默认为 1

  • shuffle: 每局牌局前是否洗牌(牌堆的顺序是否一样),一般设置为 True

    num_workers: 加载数据时采用的进程数量,默认为 0

    但是在windows操作系统下设置为大于0的值时可能会出现问题:

    "BrokenPipeError"

  • drop_last: 牌堆里有100张牌,每次取7张的话则到最后一定会剩余2张,设置为True则为舍弃之

测试代码

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
​
# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=False
                                         , transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True,
                         num_workers=0, drop_last=False)
​
# 按住ctrl点击CIFAR10查看源文件中的getitem()方法,发现返回值类型为:img, target
# 测试数据集中第一个样本图片及其对应的target
img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32])
print(target)  # 3
​
# 测试加载之后的数据集
# 对比之前的输出值即可理解之
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)  # torch.Size([4, 3, 32, 32])
    # 注意此处输出的第一个样本target值为 5 ,不同于前面的 3
    # 这是因为dataloader中的 sampler 为torch.utils.data.sampler.RandomSampler
    # 说明每次从牌堆中取出的 4 张牌是随机取的!
    print(targets)  # tensor([5, 4, 3, 7])
​
​
test_loader2 = DataLoader(dataset=test_data, batch_size=64, shuffle=True,
                         num_workers=0, drop_last=False)
​
writer = SummaryWriter("logs")
step = 0
for epoch in range(2):
    # shuffle为False,则 2 轮牌堆中的牌顺序是相同的
    for imgs, targets in test_loader2:
        writer.add_images(f"Epoch:{epoch}", imgs, step)
        step += 1
​
writer.close()

lidar_file path: /root/autodl-tmp/project/data/KITTI/object/testing/velodyne/000204.bin lidar_file path: /root/autodl-tmp/project/data/KITTI/object/testing/velodyne/000205.bin lidar_file path: /root/autodl-tmp/project/data/KITTI/object/testing/velodyne/000206.bin lidar_file path: /root/autodl-tmp/project/data/KITTI/object/testing/velodyne/000207.bin eval: 39%|█████████████████████████████▍ | 44/112 [00:06<00:07, 8.56it/s, mode=TEST, recall=0/0, rpn_iou=0]Traceback (most recent call last): File "eval_rcnn.py", line 908, in <module> eval_single_ckpt(root_result_dir) File "eval_rcnn.py", line 771, in eval_single_ckpt eval_one_epoch(model, test_loader, epoch_id, root_result_dir, logger) File "eval_rcnn.py", line 694, in eval_one_epoch ret_dict = eval_one_epoch_rpn(model, dataloader, epoch_id, result_dir, logger) File "eval_rcnn.py", line 143, in eval_one_epoch_rpn for data in dataloader: File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 435, in __next__ lidar_file path: /root/autodl-tmp/project/data/KITTI/object/testing/velodyne/000208.bin data = self._next_data() File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data return self._process_data(data) File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data data.reraise() File "/root/miniconda3/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise raise self.exc_type(msg) AssertionError: Caught AssertionError in DataLoader worker process 0.
06-07
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值