这里写自定义目录标题
1. 手写 Dataset
类
假设我们有一个自定义的图像数据集,存储在一个文件夹中,文件夹结构如下:
data/
class_1/
img1.jpg
img2.jpg
...
class_2/
img1.jpg
img2.jpg
...
我们需要实现一个 Dataset
类来加载这些数据。
实现步骤:
- 继承
torch.utils.data.Dataset
。 - 实现
__len__
方法,返回数据集的大小。 - 实现
__getitem__
方法,根据索引返回一个样本(图像和标签)。
代码实现:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class CustomImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
参数:
root_dir (str): 数据集的根目录。
transform (callable, optional): 可选的数据预处理操作。
"""
self.root_dir = root_dir
self.transform = transform
self.classes = os.listdir(root_dir) # 获取所有类别
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} # 类别到索引的映射
self.samples = self._load_samples() # 加载所有样本
def _load_samples(self):
"""加载所有样本的路径和标签"""
samples = []
for cls in self.classes:
cls_dir = os.path.join(self.root_dir, cls)
for img_name in os.listdir(cls_dir):
img_path = os.path.join(cls_dir, img_name)
samples.append((img_path, self.class_to_idx[cls]))
return samples
def __len__(self):
"""返回数据集的大小"""
return len(self.samples)
def __getitem__(self, idx):
"""根据索引返回一个样本(图像和标签)"""
img_path, label = self.samples[idx]
image = Image.open(img_path).convert('RGB') # 打开图像并转换为 RGB
if self.transform:
image = self.transform(image) # 数据预处理
return image, label # 返回图像和标签
2. 数据集划分
2.1 使用Subset创建子集
Subset
是 PyTorch 中的一个工具类,用于从数据集中创建一个子集。它允许你从原始数据集中选择特定的样本,而不需要复制数据。这在处理大型数据集时非常有用,因为你可能只需要其中的一部分数据进行训练或测试。
Subset 类的定义
Subset
类是 PyTorch 的 torch.utils.data
模块中的一个类,其定义如下:
class Subset(Dataset):
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
参数说明
dataset
: 原始数据集,通常是torch.utils.data.Dataset
的一个实例。indices
: 一个列表或数组,包含你希望从原始数据集中选择的样本的索引。
关键方法
-
__getitem__(self, idx)
:- 返回子集中第
idx
个样本。 - 实际返回的是原始数据集中
self.indices[idx]
位置的样本。
- 返回子集中第
-
__len__(self)
:- 返回子集的大小,即
indices
的长度。
- 返回子集的大小,即
示例代码
假设你有一个数据集 train_dataset
,并且你只想使用其中的前 6 个样本(索引为 0 到 5)来创建一个子集:
from torch.utils.data import Subset
# 假设 train_dataset 是一个 Dataset 对象
train_subset = Subset(train_dataset, [0, 1, 2, 3, 4, 5])
# 现在 train_subset 包含了 train_dataset 中索引为 0 到 5 的样本
使用场景
- 调试: 当你需要快速测试代码时,可以使用子集来减少数据量。
- 实验: 当你只想使用数据集的一部分进行实验时。
- 数据分割: 在划分训练集、验证集和测试集时,可以使用
Subset
来创建这些子集。
注意事项
Subset
并不会复制数据,它只是提供了一个视图(view),因此对原始数据集的修改会反映在子集中。- 如果你需要独立的数据子集,可以考虑使用
torch.utils.data.random_split
或其他方法。
2.2 使用 random_split
将数据集划分为训练集、验证集和测试集。
from torch.utils.data import random_split
# 创建自定义数据集
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = CustomImageDataset(root_dir='data', transform=transform)
# 划分数据集
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
torch.utils.data.random_split
是 PyTorch 中一个非常实用的函数,用于将数据集随机划分为多个子集。它通常用于将数据集划分为训练集、验证集和测试集。这个函数的好处是它不会复制数据,而是通过索引来创建子集,因此非常高效。
函数定义
torch.utils.data.random_split(dataset, lengths, generator=torch.Generator())
参数说明:
-
dataset
:- 要划分的原始数据集,通常是
torch.utils.data.Dataset
的一个实例。
- 要划分的原始数据集,通常是
-
lengths
:- 一个列表,包含每个子集的长度。列表的长度决定了划分的子集数量。
- 例如,
lengths=[100, 50]
会将数据集划分为两个子集,第一个子集包含 100 个样本,第二个子集包含 50 个样本。
-
generator
(可选):- 用于控制随机数生成的
torch.Generator
对象。 - 如果指定了
generator
,可以确保每次运行时的随机划分结果是一致的(通过设置随机种子)。
- 用于控制随机数生成的
返回值:
- 返回一个包含
Subset
对象的列表,每个Subset
对应一个划分的子集。
1. 基本用法
假设你有一个包含 150 个样本的数据集,你想将其划分为训练集(100 个样本)和验证集(50 个样本):
from torch.utils.data import random_split
# 假设 dataset 是一个 Dataset 对象,包含 150 个样本
train_size = 100
val_size = 50
# 划分数据集
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# 现在 train_dataset 包含 100 个样本,val_dataset 包含 50 个样本
print(len(train_dataset)) # 输出: 100
print(len(val_dataset)) # 输出: 50
2. 设置随机种子
如果你希望每次运行时的划分结果一致,可以通过 generator
参数设置随机种子:
import torch
# 设置随机种子
generator = torch.Generator().manual_seed(42)
# 划分数据集
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
3. 划分为多个子集
你也可以将数据集划分为多个子集。例如,划分为训练集、验证集和测试集:
train_size = 100
val_size = 30
test_size = 20
# 划分数据集
train_dataset, val_dataset, test_dataset = random_split(
dataset, [train_size, val_size, test_size]
)
print(len(train_dataset)) # 输出: 100
print(len(val_dataset)) # 输出: 30
print(len(test_dataset)) # 输出: 20
注意事项
-
lengths
的总和必须等于数据集的长度:- 如果
lengths
的总和小于数据集的长度,会抛出错误。 - 如果
lengths
的总和大于数据集的长度,也会抛出错误。
- 如果
-
不会复制数据:
random_split
返回的是Subset
对象,它们只是对原始数据集的引用,因此不会占用额外的内存。
-
随机性:
- 默认情况下,每次调用
random_split
时,划分结果是随机的。 - 如果需要可重复的结果,可以通过
generator
参数设置随机种子。
- 默认情况下,每次调用
与 Subset
的区别
Subset
是通过指定索引来创建子集,而random_split
是随机划分数据集。random_split
是基于Subset
实现的,但它提供了更方便的随机划分功能。
3. 数据集加载
使用 DataLoader
加载数据集,支持批量加载和多线程。
from torch.utils.data import DataLoader
# 创建 DataLoader
batch_size = 32
num_workers = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
DataLoader
是 PyTorch 中用于批量加载数据的工具,它可以将数据集包装成一个可迭代对象,方便在训练和测试过程中高效地加载数据。DataLoader
提供了许多实用的功能,例如批量加载、多线程数据加载、数据打乱等。
下面我们将详细介绍 DataLoader
的使用方法,并结合手写的 Dataset
类,展示如何在实际项目中应用 DataLoader
。
3.1 DataLoader
的功能
DataLoader
的主要功能包括:
- 批量加载:将数据集分成多个批次(batch),每次返回一个批次的数据。
- 数据打乱:在每个 epoch 开始时打乱数据顺序,避免模型过拟合。
- 多线程加载:使用多个子进程并行加载数据,加速数据加载过程。
- 自定义采样:支持自定义采样策略(如随机采样、加权采样等)。
3.2 DataLoader
的参数
DataLoader
的常用参数如下:
参数名 | 说明 |
---|---|
dataset | 要加载的数据集,必须是 torch.utils.data.Dataset 的实例。 |
batch_size | 每个批次的样本数量,默认为 1。 |
shuffle | 是否在每个 epoch 开始时打乱数据顺序,默认为 False 。 |
num_workers | 用于数据加载的子进程数量,默认为 0(主进程加载)。 |
drop_last | 如果数据集大小不能被 batch_size 整除,是否丢弃最后一个不完整的批次。 |
sampler | 自定义采样器,用于控制数据加载的顺序。 |
batch_sampler | 自定义批次采样器,用于控制批次的生成方式。 |
collate_fn | 自定义函数,用于将多个样本合并为一个批次。 |
3.3 使用 DataLoader
加载数据
假设我们已经手写了一个 Dataset
类(如前面的 CustomImageDataset
),接下来可以使用 DataLoader
加载数据。
示例代码:
from torch.utils.data import DataLoader
# 创建 DataLoader
batch_size = 32
num_workers = 4 # 使用 4 个子进程加载数据
shuffle = True # 打乱数据顺序
train_loader = DataLoader(
train_dataset, # 训练集
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers
)
val_loader = DataLoader(
val_dataset, # 验证集
batch_size=batch_size,
shuffle=False, # 验证集不需要打乱
num_workers=num_workers
)
test_loader = DataLoader(
test_dataset, # 测试集
batch_size=batch_size,
shuffle=False, # 测试集不需要打乱
num_workers=num_workers
)
4. 数据迭代
在训练或测试过程中,逐批次加载数据。
for batch_idx, (data, target) in enumerate(train_loader):
# data: 输入数据 (batch_size, channels, height, width)
# target: 标签数据 (batch_size,)
print(f"Batch {batch_idx}, Data shape: {data.shape}, Target shape: {target.shape}")
# 在这里将数据送入模型进行训练
# model.train()
# output = model(data)
# loss = criterion(output, target)
# ...
5. 数据可视化
可视化数据集中的样本,检查数据是否正确加载和预处理。
import matplotlib.pyplot as plt
import torchvision.utils as vutils
# 获取一个批次的数据
data, target = next(iter(train_loader))
# 将张量转换为图像
img_grid = vutils.make_grid(data, nrow=8, normalize=True, pad_value=0.5)
# 显示图像
plt.figure(figsize=(10, 10))
plt.imshow(img_grid.permute(1, 2, 0)) # 调整维度顺序为 (H, W, C)
plt.axis('off')
plt.title("Training Images")
plt.show()
6. 完整代码
以下是完整的代码示例:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision.utils as vutils
# 1. 手写 Dataset 类
class CustomImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.classes = os.listdir(root_dir)
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
self.samples = self._load_samples()
def _load_samples(self):
samples = []
for cls in self.classes:
cls_dir = os.path.join(self.root_dir, cls)
for img_name in os.listdir(cls_dir):
img_path = os.path.join(cls_dir, img_name)
samples.append((img_path, self.class_to_idx[cls]))
return samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
# 2. 数据预处理
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 3. 加载数据集
dataset = CustomImageDataset(root_dir='data', transform=transform)
# 4. 数据集划分
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
# 5. 创建 DataLoader
batch_size = 32
num_workers = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# 6. 数据迭代示例
for batch_idx, (data, target) in enumerate(train_loader):
print(f"Batch {batch_idx}, Data shape: {data.shape}, Target shape: {target.shape}")
# 7. 数据可视化
data, target = next(iter(train_loader))
img_grid = vutils.make_grid(data, nrow=8, normalize=True, pad_value=0.5)
plt.figure(figsize=(10, 10))
plt.imshow(img_grid.permute(1, 2, 0))
plt.axis('off')
plt.title("Training Images")
plt.show()
7. 总结
通过手写 Dataset
类,你可以灵活地加载自定义数据集,并结合 PyTorch 的 DataLoader
实现高效的数据加载和迭代。这种方法适用于各种类型的数据(如图像、文本、音频等),只需根据数据的特点调整 __getitem__
方法的实现即可。