DataLoader 是 PyTorch 中用于批量加载数据的核心工具,它能够高效地处理大规模数据集,并支持多线程数据加载、数据打乱、批处理等功能。
一、基本概念与核心参数
DataLoader 通常与 Dataset 类配合使用:
- Dataset:负责存储样本和获取单个数据样本(如图片、标签)。
- DataLoader:负责将 Dataset 中的数据打包成批次,并进行并行加载
核心参数
from torch.utils.data import DataLoader
DataLoader(
dataset, # 必须是 torch.utils.data.Dataset 的子类
batch_size=1, # 每批次包含的样本数
shuffle=False, # 是否在每个 epoch 开始时打乱数据
sampler=None, # 自定义采样策略
batch_sampler=None, # 自定义批次采样策略
num_workers=0, # 并行加载数据的进程数
collate_fn=None, # 自定义如何将样本合并为批次
pin_memory=False, # 是否将数据加载到 CUDA 固定内存中(加速 GPU 数据传输)
drop_last=False, # 是否丢弃最后一个不完整的批次
timeout=0, # 数据加载超时时间
worker_init_fn=None, # 每个 worker 进程的初始化函数
)
二、使用示例
1. 使用内置 Dataset(如 MNIST)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.1307,), (0.3081,)) # 归一化处理
])
# 加载训练集和测试集
train_dataset = datasets.MNIST(
root='./data', # 数据存储路径
train=True, # 是否为训练集
download=True, # 是否下载数据
transform=transform # 应用预处理
)
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform
)
# 创建 DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=64, # 每批次 64 个样本
shuffle=True, # 训练集需要打乱数据
num_workers=4 # 使用 4 个进程并行加载数据
)
test_loader = DataLoader(
test_dataset,
batch_size=128,
shuffle=False, # 测试集不需要打乱数据
num_workers=2
)
# 遍历数据(典型的训练循环)
for epoch in range(10):
for batch_idx, (data, target) in enumerate(train_loader):
# data.shape: [64, 1, 28, 28](批次大小、通道数、高度、宽度)
# target.shape: [64](标签)
print(f"Epoch: {epoch}, Batch: {batch_idx}, Data shape: {data.shape}")
# 模型训练代码...
2. 自定义 Dataset
当处理自己的数据集时,需要继承 torch.utils.data.Dataset 类并实现__len__ 和 __getitem__ 方法:
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data, labels, transform=None):
self.data = data # 数据
self.labels = labels # 标签
self.transform = transform # 预处理函数
def __len__(self):
return len(self.data) # 返回数据集大小
def __getitem__(self, idx):
x = self.data[idx]
y = self.labels[idx]
if self.transform:
x = self.transform(x) # 应用预处理
return x, y # 返回样本和标签
# 使用自定义 Dataset
data = torch.rand(1000, 3, 32, 32) # 1000 个 3x32x32 的随机图像
labels = torch.randint(0, 10, (1000,)) # 1000 个 0-9 的随机标签
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 验证数据加载
for batch in dataloader:
x, y = batch
print(f"Input shape: {x.shape}, Labels shape: {y.shape}")
三、高级用法
1. 自定义 collate_fn(处理变长数据)
当处理序列数据(如文本)时,样本长度可能不一致,需要自定义合并方式:
# 定义collate_fn函数,用于在DataLoader中对一个batch的数据进行处理
def my_collate_fn(examples):
# 函数功能:将本批次的所有样本转换成张量并进行填充
# 输入:
# example,样本列表,每个样本不一定等长
# 输出:
# inputs:输入到模型的tensor([num_ex,max_len])(已经填充好的)
# lengths:批次每个样本长度->tensor([num_ex])
# targets:样本标签->tensor([num_ex])一维张量
#-----------------examples:是一个列表(列表可以嵌套,但最终只有一维属性长度,没有像张量那样的多维度大小)
# 在这个任务中examp列表格式:[[[样本1],标签],[[样本2],标签],……,[[样本],标签]]
# len(examples)->列表的长度(样本数)
# example[0]->第一个样本及其标签,如[[2,45,6,21,7],0]
# 获取每个样本的长度,并将其转换为张量
lengths = torch.tensor([len(ex[0]) for ex in examples]) # 最终lengths->保存每个样本的长度,并且是一个一维张量
# 将每个样本的输入部分转换为张量
inputs = [torch.tensor(ex[0]) for ex in examples]
# 将每个样本的目标部分转换为长整型张量
targets = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
# 对batch内的样本进行padding,使其具有相同长度
inputs = pad_sequence(inputs, batch_first=True)
# 返回处理后的输入、长度和目标
return inputs, lengths, targets
# 在 DataLoader 中使用自定义 collate_fn
dataloader = DataLoader(dataset, batch_size=32, collate_fn=my_collate_fn)
2. 自定义采样策略
当需要非均匀采样(如类别平衡)时,可以使用自定义 Sampler:
from torch.utils.data import WeightedRandomSampler
# 假设样本权重(处理类别不平衡)
weights = torch.tensor([3.0, 1.0, 3.0, 1.0]) # 为每个样本分配权重
sampler = WeightedRandomSampler(weights, num_samples=4, replacement=True)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)
4134

被折叠的 条评论
为什么被折叠?



