pytorch使用概括

深度学习框架训练模型时的代码主要包含数据读取、网络构建和其它设置三个方面 。

1、 数据读取

具体步骤主要包含: 1. 定义torch.utils.data.Dataset数据接口  2. 使用torch.utils.data.DataLoader将数据接口封装成数据迭代器。

数据读取部分包含如何将你的数据转换成PyTorch框架的Tensor数据类型。当使用PyTorch构建好网络模型之后,将数据读取进去的接口全部与torch.utils.data有关,而在其接口下主要分为了datasetdataloadersamplerdistributed脚本torch.utils.data建立了模型与数据之间的桥梁。 除此之外,在torchvision中的datasets就是使用继承了torch.utils.data.Dataset基类。而torchvison.trainsforms则包含一些图片的预处理函数。综上所述,PyTorch的与数据相关的接口函数库可以总结为:

torch.utils.data

  • dataset
  • dataloader
  • sampler
  • distributed

2、 网络构建

如何自定义网络结构?在PyTorch中,构建网络的类都是基于torch.nn.Modele这个基类进行的。也就是说所有的网络结构的构建都可以通过进程该类来实现,包括torchvision.models接口中的模型实现类也是继承这个基类进行重写的。自定义网络结构可以参考models中的模型。

PyTorch框架中提供了一些方便使用的网络结构和预训练模型:torchvision.models 该接口可以直接导入指定的网络结构,并且可以选择是否用预训练模型初始化导入的网络结构。

如果要用某预训练模型为自定义的网络结构进行参数初始化,可用torch.load接口导入预训练模型,然后调用自定义的网络结构对象的load_state_dict方式进行参数初始化。

与网络构建相关库可以总结为:

  • torch.nn
    • modules文件夹
    • backends文件夹
    • _functions文件夹
    • parallel文件夹
    • utils文件夹
    • cpp
    • functional
    • init
  • torchvision
    • models

3 、网络设置

损失函数通过torch.nn包实现,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。

优化函数通过torch.optim包实现,比如torch.optim.SGD()接口表示随机梯度下降。更多优化函数可以查看PyTorch官方文档。

学习率策略通过torch.optim.lr_scheduler接口实现,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch数减少学习率。更多的学习率变化策略可以看官方文档。

多GPU训练通过torch.nn.DataParallel接口实现,比如:

model = torch.nn.DataParallel(model, device_ids=[0, 1])

表示在gpu0和1上训练模型。

torch.optim

  • adadelta
  • adagrad
  • adam
  • adamax
  • asgd
  • lbfgs
  • lr_scheduler
  • optimizer
  • rmsprop
  • rprop
  • sgd
  • sparse_adam

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值