pytorch tips

本文介绍了PyTorch中数据导入的步骤,包括使用torch.utils.data.Dataset和DataLoader,以及数据预处理的transform操作。接着讲解了模块导入,如torchvision.models、torch.nn和torch.optim。在训练部分,详细阐述了训练过程,包括更新学习率、前向传播、计算损失、反向传播和参数更新。文章最后提到了一些训练时的注意事项,如避免多线程数据加载和正确处理损失值。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、数据导入部分
torch.utils.data.Dataset,这是一个抽象类,在pytorch中所有和数据相关的类都要继承这个类来实现。
torchvision.datasets.ImageFolder接口实现数据导入。torchvision.datasets.ImageFolder会返回一个列表(比如image_datasets[‘train’]或者image_datasets[‘val]),列表中的每个值都是一个tuple,每个tuple包含图像和标签信息。列表list是不能作为模型输入的,因此在PyTorch中需要用另一个类来封装list,那就是:
torch.utils.data.DataLoader,它可以将list类型的输入数据封装成Tensor数据格式,以备模型使用。这里是对图像和标签分别封装成一个Tensor。
data_transforms是一个字典。主要是进行一些图像预处理,比如resize、crop等。实现的时候采用的是torchvision.transforms模块,比如torchvision.transforms.Compose是用来管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是对于torchvision.transforms.RandomSizedCrop和**transforms.RandomHorizontalFlip()**等,输入对象都是PIL Image,也就是用python的PIL库读进来的图像内容,而transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])的作用对象需要是一个Tensor,因此在transforms.Normalize([0.5, 0.5, 0.4], [0.2, 0.2, 0.5])之前有一个
**transforms.ToTensor()**就是用来生成Tensor的。另外tr

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值