使用Subset, ConcatDataset 时的注意点

在使用 pytorch lightning 过程中,

如何实现每个epcoh 会重新加载一次DataLoader ,

从而实现自己每个epoch 训练时, 可以使用不同的样本,

这里,笔者的需求是每个epoch 中加载不同的normal 的样本;

从而避免一次载入,很多个正常样本, 导致训练过程中,正常样本与异常样本导致的类别不均衡问题。

1. 将原始的整个数据集进行拆分

1.1 使用Subset

将原始的整个数据集拆分成两个子集;

    def setup(self, stage=None):
        # Assuming labels are stored or accessible
        normal_indices = []
        abnormal_indices = []

        for idx in range(len(self.train_data)):
            # Access the label for each sample
            # Modify this line according to how you retrieve the label
            label = self.train_data.get_label(idx)  # Implement get_label method if needed
            if label == NORMAL_LABEL:
                normal_indices.append(idx)
            else:
                abnormal_indices.append(idx)

        # Create Subsets
        self.normal_data = Subset(self.train_data, normal_indices)
        self.abnormal_data = Subset(self.train_data, abnormal_indices)

1.2 Subset, ConcatDataset的运作机制

  1. 在生成正常类型的子集时, 此时传入的索引sampled_normal_indices 是属于正常样本集合中的索引, 而不在是属于整个 self.train_data 中的索引, 因此才需要Subset 将sampled_normal_indices 中的索引映射到原始的 self.train_data 中的索引。

  2. 当使用ConcatDataset([normal_subset, abnormal_subset]) , 此时,每创建一个新的batch 的索引时, 这些batch 中索引的范围属于 ConcatDataset 数据集的索引,而不再是原始的self.train_data中的索引, 在ConcatDataset中, 会自动判别此时batch 中的索引是属于 normal subset 还是 abnormal subset 中的索引,然后又会依据Subset 类中属性将其映射到最初的self.train_data 中的索引。

1.3 构建新的数据集

将两个子数据集使用ConcatDataset 拼接生成新的数据集并使用
DataLoader 进行加载;

def train_dataloader(self):
    # Randomly select indices relative to the normal subset
    sampled_indices_in_subset = np.random.choice(len(self.normal_data), size=80, replace=False)
    # Map back to original indices in self.train_data
    sampled_normal_indices = [self.normal_data.indices[i] for i in sampled_indices_in_subset]
    # Create a new Subset with these original indices
    normal_subset = Subset(self.train_data, sampled_normal_indices)

    # Abnormal data remains the same
    abnormal_subset = self.abnormal_data

    # Combine the normal and abnormal subsets
    combined_dataset = ConcatDataset([normal_subset, abnormal_subset])

    # Create DataLoader
    train_loader = DataLoader(
        combined_dataset,
        batch_size=self.hparams["training"]["batch_size"],
        num_workers=self.num_workers,
        shuffle=True,
        collate_fn=self.custom_collate,
        pin_memory=True,
    )

    return train_loade

2. 不要忘记 reload_dataloader

最后不要忘记, 在pl.trainer 中 设置每个epoch 重新加载一次DataLoader(), 这样每个epoch 中训练的数据便是不同的样本构成的了。

    trainer = pl.Trainer(
        precision=config["training"]["precision"],
        max_epochs=n_epochs,  # note,  这里每个epoch 都会重新加载一次DadatLoader, 用于实现自定义的动态数据集加载过程。
        reload_dataloaders_every_n_epochs=config["training"]["reload_dataloaders_every_n_epochs"],
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值