同步知乎页:https://zhuanlan.zhihu.com/p/66707105
该过程有两步:
第一:在python中导出pytorch模型
第二:在c++中加载并执行
第一步:convert Pytorch func/model to Torch Script
有2中方法可以完成该任务。
一、 torch.jit.trace
-
function
def foo(x):
return torch.sigmoid(x)
# trace func
tmp_in = torch.rand(5)
script_func = torch.jit.trace(foo, tmp_in)
# use traced func
x = torch.rand(3)
out = script_func(x)
-
net module
class Mymodel(torch.nn.Module):
def __init__(self):
super(Mymodel, self).__init__()
self.conv=torch.nn.Conv2d(3,2,2)
def forward(self, x):
out = self.conv(x)
retu