使用PyTorch导出JIT模型:C++ API与libtorch实战

PyTorch导出JIT模型并用C++ API libtorch调用

本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。

Step1:导出模型

首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。

# export_jit_model.py
import torch
import torchvision.models as models

model = models.resnet50(pretrained=True)
model.eval()

example_input = torch.rand(1, 3, 224, 224)

jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')

导出 JIT 模型的方式有两种:trace 和 script。

我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。

在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth

Step 2:安装libtorch

接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:

wget https
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值