一、数据导入部分
torch.utils.data.Dataset,这是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现。
torchvision.datasets.ImageFolder接口实现数据导入。torchvision.datasets.ImageFolder会返回一个列表(比如image_datasets[‘train’]或者image_datasets[‘val]),列表中的每个值都是一个tuple,每个tuple包含图像和标签信息。列表list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:
torch.utils.data.DataLoader,它可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。这里是对图像和标签分别封装成一个Tensor。
data_transforms是一个字典。主要是进行一些图像预处理,比如resize、crop等。实现的时候采用的是torchvision.transforms模块,比如torchvision.transforms.Compose是用来管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是对于torchvision.transforms.RandomSizedCrop和**transforms.RandomHorizontalFlip()**等,输入对象都是PIL Image,也就是用python的PIL库读进来的图像内容,而transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])的作用对象需要是一个Tensor,因此在transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])之前有一个
**transforms.ToTensor()**就是用来生成Tensor的。另外tr