PyTorch数据加载流程解析

1. 定义最简单的Dataset
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data  # 假设data是一个列表,如[10, 20, 30, 40]
    
    def __len__(self):
        return len(self.data)  # 返回数据总量
    
    def __getitem__(self, idx):
        return self.data[idx]  # 返回单个数据样本

# 示例数据
my_data = [10, 20, 30, 40]
dataset = MyDataset(my_data)
2. 创建DataLoader
loader = DataLoader(dataset, 
                   batch_size=2,  # 每批2个样本
                   shuffle=True)  # 打乱数据顺序
3. 遍历DataLoader时的内部操作

当执行以下代码时:

for batch in loader:
    print(batch)

实际发生的步骤

  1. DataLoader自动调用dataset.__len__()获取数据总量(这里是4)
  2. 根据batch_size=2生成索引序列(如[1,3][0,2],因shuffle=True而随机)
    • 索引生成逻辑
    • PyTorch通过以下设计保证索引不重复:
      • 采样器隔离:每个epoch生成独立的随机排列。
      • 批次切割:按固定步长切分排列,避免交叉。
      • 全局控制Sampler严格管理索引分配。
  3. 对每个索引调用dataset.__getitem__(idx)
    • 第一次取idx=1idx=3 → 返回2040
    • 自动堆叠为张量tensor([20, 40])
  4. 输出结果示例:
    tensor([20, 40])  # 第一批
    tensor([10, 30])  # 第二批
    
4. 关键点图解
数据集: [10, 20, 30, 40]
           │   │   │   │
索引:     0   1   2   3

DataLoader操作:
1. 随机选索引(如[1,3]) → 取数据2040 → 堆叠为tensor([20, 40])
2. 随机选索引(如[0,2]) → 取数据1030 → 堆叠为tensor([10, 30])
5. 如果数据是元组

假设每个样本是(用户ID, 物品ID)

class PairDataset(Dataset):
    def __init__(self):
        self.pairs = [(1,101), (2,102), (3,103)]  # (用户, 物品)
    
    def __len__(self):
        return len(self.pairs)  # 必须实现:返回数据总量

    def __getitem__(self, idx):
        return self.pairs[idx]  # 返回一个元组

loader = DataLoader(PairDataset(), batch_size=2)
for batch in loader:
    print(batch)

输出:

# 每个元组字段自动堆叠
[tensor([1, 2]), tensor([101, 102])]  # 第一批
[tensor([3]), tensor([103])]          # 第二批(最后不足batch_size)

总结

  1. Dataset:定义数据存储和单个样本获取方式(必须实现__len____getitem__
  2. DataLoader
    • 根据batch_size生成索引
    • 自动调用__getitem__获取数据
    • 将样本堆叠成批次张量
  3. 核心特性
    • 支持多进程加速(num_workers参数)
    • 自动打乱数据(shuffle=True
    • 灵活处理各种数据结构(标量、元组、字典等)

这就是PyTorch数据加载的核心机制!其他复杂功能都是基于这个简单流程的扩展。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值