https://zhuanlan.zhihu.com/p/200876072
class DatasetSplit(Dataset):
"""An abstract Dataset class wrapped around Pytorch Dataset class.
"""
def __init__(self, dataset, idxs):
self.dataset = dataset
self.idxs = [int(i) for i in idxs]
def __len__(self):
return len(self.idxs)
def __getitem__(self, item):
image, label = self.dataset[self.idxs[item]]
return torch.tensor(image), torch.tensor(label)
PyTorch数据集划分

本文介绍了一个用于PyTorch的数据集划分类classDatasetSplit,该类继承自PyTorch的Dataset类,能够实现对已有数据集进行指定索引的子集划分。
3906

被折叠的 条评论
为什么被折叠?



