PyTorch中的torchscript模型转换为普通的nn.Module模型
from models import *
def ptl2pt():
map_location = 'cpu' #如果要在GPU上运行,则此处设置为 'cuda'
model = torch.jit.load("test.ptl",map_location = map_location)
new_model = Net(num_feats=32, depth_chanels=1, color_channel=3, kernel_size=3).to(device)
state_dict = model.state_dict() #从torchscript 模型中提取参数
new_model.load_state_dict(state_dict) # 将参数加载到新模型中
torch.save(new_model.state_dict(),"test.pth")# 保存新模型