在使用 pytorch lightning 过程中,
如何实现每个epcoh 会重新加载一次DataLoader ,
从而实现自己每个epoch 训练时, 可以使用不同的样本,
这里,笔者的需求是每个epoch 中加载不同的normal 的样本;
从而避免一次载入,很多个正常样本, 导致训练过程中,正常样本与异常样本导致的类别不均衡问题。
1. 将原始的整个数据集进行拆分
1.1 使用Subset
将原始的整个数据集拆分成两个子集;
def setup(self, stage=None):
# Assuming labels are stored or accessible
normal_indices = []
abnormal_indices = []
for idx in range(len(self.train_data)):
# Access the label for each sample
# Modify this line according to how you retrieve the label
label = self.train_data.get_label(idx) # Implement get_label method if needed
if label == NORMAL_LABEL:
normal_indices.append(idx)
else:
abnormal_indices.append(idx)
# Create Subsets
self.normal_data = Subset(self.train_data, normal_indices)
self.abnormal_data = Subset(self.train_data, abnormal_indices)
1.2 Subset, ConcatDataset的运作机制
-
在生成正常类型的子集时, 此时传入的索引sampled_normal_indices 是属于正常样本集合中的索引, 而不在是属于整个 self.train_data 中的索引, 因此才需要Subset 将sampled_normal_indices 中的索引映射到原始的 self.train_data 中的索引。
-
当使用ConcatDataset([normal_subset, abnormal_subset]) , 此时,每创建一个新的batch 的索引时, 这些batch 中索引的范围属于 ConcatDataset 数据集的索引,而不再是原始的self.train_data中的索引, 在ConcatDataset中, 会自动判别此时batch 中的索引是属于 normal subset 还是 abnormal subset 中的索引,然后又会依据Subset 类中属性将其映射到最初的self.train_data 中的索引。
1.3 构建新的数据集
将两个子数据集使用ConcatDataset 拼接生成新的数据集并使用
DataLoader 进行加载;
def train_dataloader(self):
# Randomly select indices relative to the normal subset
sampled_indices_in_subset = np.random.choice(len(self.normal_data), size=80, replace=False)
# Map back to original indices in self.train_data
sampled_normal_indices = [self.normal_data.indices[i] for i in sampled_indices_in_subset]
# Create a new Subset with these original indices
normal_subset = Subset(self.train_data, sampled_normal_indices)
# Abnormal data remains the same
abnormal_subset = self.abnormal_data
# Combine the normal and abnormal subsets
combined_dataset = ConcatDataset([normal_subset, abnormal_subset])
# Create DataLoader
train_loader = DataLoader(
combined_dataset,
batch_size=self.hparams["training"]["batch_size"],
num_workers=self.num_workers,
shuffle=True,
collate_fn=self.custom_collate,
pin_memory=True,
)
return train_loade
2. 不要忘记 reload_dataloader
最后不要忘记, 在pl.trainer 中 设置每个epoch 重新加载一次DataLoader(), 这样每个epoch 中训练的数据便是不同的样本构成的了。
trainer = pl.Trainer(
precision=config["training"]["precision"],
max_epochs=n_epochs, # note, 这里每个epoch 都会重新加载一次DadatLoader, 用于实现自定义的动态数据集加载过程。
reload_dataloaders_every_n_epochs=config["training"]["reload_dataloaders_every_n_epochs"],