PyTorch如何同时让两个DataLoader打乱的顺序相同?
在深度学习领域,PyTorch因其灵活易用的特性而广受欢迎。然而,当我们处理复杂的任务时,往往会遇到一些棘手的问题。例如,在训练过程中,我们可能需要同时使用两个DataLoader,并且希望它们在每个epoch中打乱数据的顺序完全一致。这听起来似乎很简单,但实现起来却并不容易。本文将深入探讨如何在PyTorch中实现这一目标,并分享一些实用的技巧和代码示例。
为什么需要同步打乱两个DataLoader?
在许多场景下,我们需要同时使用两个或多个DataLoader。例如,一个DataLoader用于加载图像数据,另一个DataLoader用于加载对应的标签数据。如果这两个DataLoader的打乱顺序不一致,那么在训练过程中,模型可能会接收到错误的输入和标签对,导致训练效果不佳甚至完全失败。
另一个常见的应用场景是在生成对抗网络(GAN)中,生成器和判别器通常需要从同一个数据集中采样,但采样的顺序需要保持一致,以确保生成器和判别器之间的同步训练。
如何实现同步打乱?
方法一:自定义Sampler
PyTorch提供了一个灵活的Sampler
类,允许我们自定义数据的采样方式。通过实现一个自定义的Sampler
,我们可以确保两个DataLoader在每个epoch中使用相同的随机顺序。
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
class CustomSampler(Sampler):
def __init__(self, data_source, seed=None):
self.data_source = data_source
self.seed = seed
self.indices = list(range(len(data_source)))
def __iter__(self):
if self.seed is not None:
generator = torch.Generator()
generator.manual_seed(self.seed)
indices = torch.randperm(len(self.data_source), generator=generator).tolist()
else:
indices = torch.randperm(len(self.data_source)).tolist()
return iter(indices)
def __len__(self):
return len(self.data_source)
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.data)
# 示例数据
data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))
dataset = MyDataset(data, labels)
sampler = CustomSampler(dataset, seed=42)
dataloader1 = DataLoader(dataset, batch_size=16