Dataloader读取dataset子数据
最近遇到了一个问题,卡了好久终于解决了,想写一篇博客来记录一下。
我希望完成的是使用dataloader从一个dataset中根据先前设定的index来读取数据(只读取这些index的数据)。
解决方案如下(thanks to ChatGPT):
import torch
from torch.utils.data import Dataset, DataLoader, Subset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
sample = self.data[index]
return sample
def __len__(self):
return len(self.data)
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 指定感兴趣的索引
indices = [0, 2, 4]
# 创建子数据集,只包含指定索引的样本
subset = Subset(dataset, indices)
# 创建DataLoader,并遍历子数据集
dataloader = DataLoader(subset, batch_size=1, shuffle=False)
for batch in dataloader:
print(batch)