深度学习--数据加载


1. 手写 Dataset

假设我们有一个自定义的图像数据集,存储在一个文件夹中,文件夹结构如下:

data/
    class_1/
        img1.jpg
        img2.jpg
        ...
    class_2/
        img1.jpg
        img2.jpg
        ...

我们需要实现一个 Dataset 类来加载这些数据。

实现步骤:

  1. 继承 torch.utils.data.Dataset
  2. 实现 __len__ 方法,返回数据集的大小。
  3. 实现 __getitem__ 方法,根据索引返回一个样本(图像和标签)。

代码实现:

import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        参数:
            root_dir (str): 数据集的根目录。
            transform (callable, optional): 可选的数据预处理操作。
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)  # 获取所有类别
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}  # 类别到索引的映射
        self.samples = self._load_samples()  # 加载所有样本

    def _load_samples(self):
        """加载所有样本的路径和标签"""
        samples = []
        for cls in self.classes:
            cls_dir = os.path.join(self.root_dir, cls)
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                samples.append((img_path, self.class_to_idx[cls]))
        return samples

    def __len__(self):
        """返回数据集的大小"""
        return len(self.samples)

    def __getitem__(self, idx):
        """根据索引返回一个样本(图像和标签)"""
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')  # 打开图像并转换为 RGB

        if self.transform:
            image = self.transform(image)  # 数据预处理

        return image, label  # 返回图像和标签

2. 数据集划分

2.1 使用Subset创建子集

Subset 是 PyTorch 中的一个工具类,用于从数据集中创建一个子集。它允许你从原始数据集中选择特定的样本,而不需要复制数据。这在处理大型数据集时非常有用,因为你可能只需要其中的一部分数据进行训练或测试。

Subset 类的定义

Subset 类是 PyTorch 的 torch.utils.data 模块中的一个类,其定义如下:

class Subset(Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
参数说明
  • dataset: 原始数据集,通常是 torch.utils.data.Dataset 的一个实例。
  • indices: 一个列表或数组,包含你希望从原始数据集中选择的样本的索引。
关键方法
  1. __getitem__(self, idx):

    • 返回子集中第 idx 个样本。
    • 实际返回的是原始数据集中 self.indices[idx] 位置的样本。
  2. __len__(self):

    • 返回子集的大小,即 indices 的长度。
示例代码

假设你有一个数据集 train_dataset,并且你只想使用其中的前 6 个样本(索引为 0 到 5)来创建一个子集:

from torch.utils.data import Subset

# 假设 train_dataset 是一个 Dataset 对象
train_subset = Subset(train_dataset, [0, 1, 2, 3, 4, 5])

# 现在 train_subset 包含了 train_dataset 中索引为 0 到 5 的样本
使用场景
  • 调试: 当你需要快速测试代码时,可以使用子集来减少数据量。
  • 实验: 当你只想使用数据集的一部分进行实验时。
  • 数据分割: 在划分训练集、验证集和测试集时,可以使用 Subset 来创建这些子集。
注意事项
  • Subset 并不会复制数据,它只是提供了一个视图(view),因此对原始数据集的修改会反映在子集中。
  • 如果你需要独立的数据子集,可以考虑使用 torch.utils.data.random_split 或其他方法。

2.2 使用 random_split 将数据集划分为训练集、验证集和测试集。

from torch.utils.data import random_split

# 创建自定义数据集
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = CustomImageDataset(root_dir='data', transform=transform)

# 划分数据集
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

torch.utils.data.random_split 是 PyTorch 中一个非常实用的函数,用于将数据集随机划分为多个子集。它通常用于将数据集划分为训练集、验证集和测试集。这个函数的好处是它不会复制数据,而是通过索引来创建子集,因此非常高效。


函数定义
torch.utils.data.random_split(dataset, lengths, generator=torch.Generator())
参数说明:
  1. dataset:

    • 要划分的原始数据集,通常是 torch.utils.data.Dataset 的一个实例。
  2. lengths:

    • 一个列表,包含每个子集的长度。列表的长度决定了划分的子集数量。
    • 例如,lengths=[100, 50] 会将数据集划分为两个子集,第一个子集包含 100 个样本,第二个子集包含 50 个样本。
  3. generator (可选):

    • 用于控制随机数生成的 torch.Generator 对象。
    • 如果指定了 generator,可以确保每次运行时的随机划分结果是一致的(通过设置随机种子)。
返回值:
  • 返回一个包含 Subset 对象的列表,每个 Subset 对应一个划分的子集。

1. 基本用法

假设你有一个包含 150 个样本的数据集,你想将其划分为训练集(100 个样本)和验证集(50 个样本):

from torch.utils.data import random_split

# 假设 dataset 是一个 Dataset 对象,包含 150 个样本
train_size = 100
val_size = 50

# 划分数据集
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# 现在 train_dataset 包含 100 个样本,val_dataset 包含 50 个样本
print(len(train_dataset))  # 输出: 100
print(len(val_dataset))    # 输出: 50
2. 设置随机种子

如果你希望每次运行时的划分结果一致,可以通过 generator 参数设置随机种子:

import torch

# 设置随机种子
generator = torch.Generator().manual_seed(42)

# 划分数据集
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)
3. 划分为多个子集

你也可以将数据集划分为多个子集。例如,划分为训练集、验证集和测试集:

train_size = 100
val_size = 30
test_size = 20

# 划分数据集
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

print(len(train_dataset))  # 输出: 100
print(len(val_dataset))    # 输出: 30
print(len(test_dataset))   # 输出: 20

注意事项
  1. lengths 的总和必须等于数据集的长度

    • 如果 lengths 的总和小于数据集的长度,会抛出错误。
    • 如果 lengths 的总和大于数据集的长度,也会抛出错误。
  2. 不会复制数据

    • random_split 返回的是 Subset 对象,它们只是对原始数据集的引用,因此不会占用额外的内存。
  3. 随机性

    • 默认情况下,每次调用 random_split 时,划分结果是随机的。
    • 如果需要可重复的结果,可以通过 generator 参数设置随机种子。

Subset 的区别
  • Subset 是通过指定索引来创建子集,而 random_split 是随机划分数据集。
  • random_split 是基于 Subset 实现的,但它提供了更方便的随机划分功能。

3. 数据集加载

使用 DataLoader 加载数据集,支持批量加载和多线程。

from torch.utils.data import DataLoader

# 创建 DataLoader
batch_size = 32
num_workers = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

DataLoader 是 PyTorch 中用于批量加载数据的工具,它可以将数据集包装成一个可迭代对象,方便在训练和测试过程中高效地加载数据。DataLoader 提供了许多实用的功能,例如批量加载、多线程数据加载、数据打乱等。

下面我们将详细介绍 DataLoader 的使用方法,并结合手写的 Dataset 类,展示如何在实际项目中应用 DataLoader


3.1 DataLoader 的功能

DataLoader 的主要功能包括:

  • 批量加载:将数据集分成多个批次(batch),每次返回一个批次的数据。
  • 数据打乱:在每个 epoch 开始时打乱数据顺序,避免模型过拟合。
  • 多线程加载:使用多个子进程并行加载数据,加速数据加载过程。
  • 自定义采样:支持自定义采样策略(如随机采样、加权采样等)。

3.2 DataLoader 的参数

DataLoader 的常用参数如下:

参数名说明
dataset要加载的数据集,必须是 torch.utils.data.Dataset 的实例。
batch_size每个批次的样本数量,默认为 1。
shuffle是否在每个 epoch 开始时打乱数据顺序,默认为 False
num_workers用于数据加载的子进程数量,默认为 0(主进程加载)。
drop_last如果数据集大小不能被 batch_size 整除,是否丢弃最后一个不完整的批次。
sampler自定义采样器,用于控制数据加载的顺序。
batch_sampler自定义批次采样器,用于控制批次的生成方式。
collate_fn自定义函数,用于将多个样本合并为一个批次。

3.3 使用 DataLoader 加载数据

假设我们已经手写了一个 Dataset 类(如前面的 CustomImageDataset),接下来可以使用 DataLoader 加载数据。

示例代码:

from torch.utils.data import DataLoader

# 创建 DataLoader
batch_size = 32
num_workers = 4  # 使用 4 个子进程加载数据
shuffle = True   # 打乱数据顺序

train_loader = DataLoader(
    train_dataset,  # 训练集
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers
)

val_loader = DataLoader(
    val_dataset,    # 验证集
    batch_size=batch_size,
    shuffle=False,  # 验证集不需要打乱
    num_workers=num_workers
)

test_loader = DataLoader(
    test_dataset,   # 测试集
    batch_size=batch_size,
    shuffle=False,  # 测试集不需要打乱
    num_workers=num_workers
)

4. 数据迭代

在训练或测试过程中,逐批次加载数据。

for batch_idx, (data, target) in enumerate(train_loader):
    # data: 输入数据 (batch_size, channels, height, width)
    # target: 标签数据 (batch_size,)
    print(f"Batch {batch_idx}, Data shape: {data.shape}, Target shape: {target.shape}")

    # 在这里将数据送入模型进行训练
    # model.train()
    # output = model(data)
    # loss = criterion(output, target)
    # ...

5. 数据可视化

可视化数据集中的样本,检查数据是否正确加载和预处理。

import matplotlib.pyplot as plt
import torchvision.utils as vutils

# 获取一个批次的数据
data, target = next(iter(train_loader))

# 将张量转换为图像
img_grid = vutils.make_grid(data, nrow=8, normalize=True, pad_value=0.5)

# 显示图像
plt.figure(figsize=(10, 10))
plt.imshow(img_grid.permute(1, 2, 0))  # 调整维度顺序为 (H, W, C)
plt.axis('off')
plt.title("Training Images")
plt.show()

6. 完整代码

以下是完整的代码示例:

import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# 1. 手写 Dataset 类
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = self._load_samples()

    def _load_samples(self):
        samples = []
        for cls in self.classes:
            cls_dir = os.path.join(self.root_dir, cls)
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                samples.append((img_path, self.class_to_idx[cls]))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# 2. 数据预处理
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 3. 加载数据集
dataset = CustomImageDataset(root_dir='data', transform=transform)

# 4. 数据集划分
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# 5. 创建 DataLoader
batch_size = 32
num_workers = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# 6. 数据迭代示例
for batch_idx, (data, target) in enumerate(train_loader):
    print(f"Batch {batch_idx}, Data shape: {data.shape}, Target shape: {target.shape}")

# 7. 数据可视化
data, target = next(iter(train_loader))
img_grid = vutils.make_grid(data, nrow=8, normalize=True, pad_value=0.5)
plt.figure(figsize=(10, 10))
plt.imshow(img_grid.permute(1, 2, 0))
plt.axis('off')
plt.title("Training Images")
plt.show()

7. 总结

通过手写 Dataset 类,你可以灵活地加载自定义数据集,并结合 PyTorch 的 DataLoader 实现高效的数据加载和迭代。这种方法适用于各种类型的数据(如图像、文本、音频等),只需根据数据的特点调整 __getitem__ 方法的实现即可。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值