pytorch对数据集进行重新采样

部署运行你感兴趣的模型镜像

背景
当不同类型数据的数量差别巨大的时候,比如猫有200张训练图片,而狗有2000张,很容易出现模型只能学到狗的特征,导致准确率无法提升的情况。

这时候,一种可行的方法就是对原始数据集进行采样,从而生成猫、狗图片数量接近的新数据集。这个新数据集中可能猫、狗图片都各有500张,其中猫的图片有一部分重复的,而狗的2000张图片中有一部分没有被采样到,但是这时候新数据集的数据分布是均衡的,就可以比较好的训练了。

操作方法
我们知道pytorch训练一般都是用的DataLoader加载数据的,我们可以通过给Dataloader传入一个sampler的采样器进行采样操作。

train_loader = DataLoader( train_dataset, batch_size=256, num_workers=2, sampler=sampler)

采样器sampler有多种,大家可以根据自己需要研究一下,这里我们使用一个按权重采样的WeightedRandomSampler。其作用是:我们可以人为的给每张图片定一个被抽取到的概率,一般每一类的所有图片的概率可以一样,然后就按每个图片的这个概率对整个数据集进行重新采样。

比如:猫只有200张图片,我们设置取到每张猫的图片的概率为1/200,而狗有2000张图片,我们设置取到每张狗的概率为1/2000。这样虽然狗的图片比较多,但我们取到猫和狗的概率是一样的,只是猫会有一些重复,而狗有一些不会取到,最终形成的新数据集就平衡了。

参考 https://blog.youkuaiyun.com/tyfwin/article/details/108435756

在这里插入图片描述
注意
注意上图的replacement参数,为True表示有放回的采样,也就是我们上边说的那种采样,有部分数据重复,有部分数据没有出现;为False表示不放回的采样,即采样后的数据集跟原来一样,只是内部数据的顺序有些变化,概率大的可能会在前边,这主要作用于有序的数据。num_samples定义采样的次数,也即采样后的数据集数目,一般设为跟原来一样。

代码

参考 https://www.cnblogs.com/king-lps/p/11004653.html

# 定义每个类别采样的权重,这个只做参考,可以根据自己需要随便定义
target = train_dataset.targets
class_sample_count = np.array([len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
# 读取原始数据为datasets对象                                                                              
dataset_train = datasets.ImageFolder(traindir)                                                                                                         

# 在DataLoader的时候传入采样器即可                                                                                
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, sampler = sampler) 

您可能感兴趣的与本文相关的镜像

ACE-Step

ACE-Step

音乐合成
ACE-Step

ACE-Step是由中国团队阶跃星辰(StepFun)与ACE Studio联手打造的开源音乐生成模型。 它拥有3.5B参数量,支持快速高质量生成、强可控性和易于拓展的特点。 最厉害的是,它可以生成多种语言的歌曲,包括但不限于中文、英文、日文等19种语言

PyTorch 中,加载数据时进行重采样通常可以通过 `torch.utils.data.WeightedRandomSampler` 或 `torch.utils.data.RandomSampler` 等采样器来实现。以下是使用 `WeightedRandomSampler` 进行重采样的示例代码: ```python import torch from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler # 自定义一个简单的数据集 class SimpleDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] # 示例数据 data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] dataset = SimpleDataset(data) # 假设我们要对数据进行重采样,为每个样本分配权重 # 这里简单地给每个样本分配不同的权重 weights = [1, 2, 1, 2, 1, 2, 1, 2, 1, 2] # 创建 WeightedRandomSampler sampler = WeightedRandomSampler(weights=weights, num_samples=len(dataset), replacement=True) # 创建 DataLoader 并使用采样器 dataloader = DataLoader(dataset, batch_size=2, sampler=sampler) # 遍历数据加载器 for batch in dataloader: print(batch) ``` 在上述代码中,首先定义了一个简单的数据集 `SimpleDataset`。然后为每个样本分配了权重 `weights`,通过 `WeightedRandomSampler` 根据这些权重进行重采样。最后将采样器传递给 `DataLoader` 来加载数据。 如果想要根据类别进行重采样,例如处理类别不平衡问题,可以按照以下方式操作: ```python import torch from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler import numpy as np # 自定义数据集 class ClassDataset(Dataset): def __init__(self, features, labels): self.features = features self.labels = labels def __len__(self): return len(self.labels) def __getitem__(self, idx): return self.features[idx], self.labels[idx] # 示例数据 features = np.random.randn(100, 10) labels = np.random.randint(0, 2, 100) dataset = ClassDataset(features, labels) # 计算每个类别的样本数 class_sample_count = np.array([len(np.where(labels == t)[0]) for t in np.unique(labels)]) # 计算每个样本的权重 weight = 1. / class_sample_count samples_weight = np.array([weight[t] for t in labels]) samples_weight = torch.from_numpy(samples_weight) samples_weight = samples_weight.double() # 创建采样器 sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) # 创建 DataLoader dataloader = DataLoader(dataset, batch_size=10, sampler=sampler) # 遍历数据加载器 for batch_features, batch_labels in dataloader: print(batch_features.shape, batch_labels) ``` 这个代码示例针对类别不平衡问题,根据每个类别的样本数量计算样本权重,然后使用 `WeightedRandomSampler` 进行重采样,使得每个类别在采样中被选中的概率更加均衡。
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值