Dataset
Pytorch中数据集被抽象为一个抽象类torch.utils.data.Dataset
,所有的数据集都应该继承这个类,并override以下两项:
__len__
:代表样本数量。len(obj)
等价于obj.__len__()
。__getitem__
:返回一条数据或一个样本。obj[index]
等价于obj.__getitem__
。建议将节奏的图片等高负载的操作放到这里,因为多进程时会并行调用这个函数,这样做可以加速。
dataset中应尽量只包含只读对象,避免修改任何可变对象。因为如果使用多进程,可变对象要加锁,但后面讲到的dataloader的设计使其难以加锁。如下面例子中的self.num
可能在多进程下出问题:
class BadDataset(Dataset):
def __init__(self):