目的
将pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。
pytorch官方链接
方法
共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。
trace
通过使用示例输入对模型的结构进行一次评估,并记录这些输入在模型中的变化过程,从而捕获模型的结构。
class MyModule(nn.Module):
def __init__(self):
super(MyModule,self).__init__()
self.conv1 = nn.Conv2d(1,3,3