总结:
如果要实现动态变化的dataloader,必须重写dataset的len方法
import torch
from torch.utils.data import DataLoader, TensorDataset
随机生成 100 个样本,每个样本 10 个特征
data = torch.randn(10, 1) # 100个样本,每个样本10个特征
随机生成 100 个标签(二分类标签,0 或 1)
targets = torch.randint(0, 2, (10,)) # 100个标签,取值为0或1
创建 TensorDataset
class MyDataset(TensorDataset):
def init(self, data, targets):
super(MyDataset, self).init(data, targets)
self.data = data
self.targets = targets
def len(self):
return len(self.data)
def getitem(self, index):
# data = super(MyDataset, self).getitem(index)
return self.data[index], self.targets[index]
def aug_data(self,extra_data,extra_targets):
self.data=torch.cat((self.data,extra_data),0)
self.targets=torch.cat((self.targets,extra_tar