一、数据导入部分
torch.utils.data.Dataset,这是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现。
torchvision.datasets.ImageFolder接口实现数据导入。torchvision.datasets.ImageFolder会返回一个列表(比如image_datasets[‘train’]或者image_datasets[‘val]),列表中的每个值都是一个tuple,每个tuple包含图像和标签信息。列表list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:
torch.utils.data.DataLoader,它可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。这里是对图像和标签分别封装成一个Tensor。
data_transforms是一个字典。主要是进行一些图像预处理,比如resize、crop等。实现的时候采用的是torchvision.transforms模块,比如torchvision.transforms.Compose是用来管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是对于torchvision.transforms.RandomSizedCrop和**transforms.RandomHorizontalFlip()**等,输入对象都是PIL Image,也就是用python的PIL库读进来的图像内容,而transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])的作用对象需要是一个Tensor,因此在transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])之前有一个
**transforms.ToTensor()**就是用来生成Tensor的。另外tr
pytorch tips
最新推荐文章于 2024-06-30 19:09:04 发布
本文介绍了PyTorch中数据导入的步骤,包括使用torch.utils.data.Dataset和DataLoader,以及数据预处理的transform操作。接着讲解了模块导入,如torchvision.models、torch.nn和torch.optim。在训练部分,详细阐述了训练过程,包括更新学习率、前向传播、计算损失、反向传播和参数更新。文章最后提到了一些训练时的注意事项,如避免多线程数据加载和正确处理损失值。

最低0.47元/天 解锁文章
2690

被折叠的 条评论
为什么被折叠?



