import torch
from mymodel import mymodel   # the original import path is elided; point this at your own model class
from data import cfg_mnet, cfg_re50

def check_keys(model, pretrained_state_dict):
    ''' Report how many checkpoint keys match the model's state_dict and fail if none match. '''
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys: {}'.format(len(missing_keys)))
    print('Unused checkpoint keys: {}'.format(len(unused_pretrained_keys)))
    print('Used keys: {}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, 'loaded nothing from the pretrained checkpoint'
    return True

def remove_prefix(state_dict, prefix):
    ''' Old-style checkpoints store every parameter name with a shared prefix such as 'module.'; strip it. '''
    print('remove prefix \'{}\''.format(prefix))
    f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
    return {f(key): value for key, value in state_dict.items()}
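
# Note: the 'module.' prefix handled by remove_prefix() typically comes from checkpoints
# saved through nn.DataParallel. A minimal, hypothetical sketch of where it originates
# (toy model used for illustration only):
#
#   import torch.nn as nn
#   toy = nn.Linear(4, 2)
#   print(list(toy.state_dict().keys()))                    # ['weight', 'bias']
#   print(list(nn.DataParallel(toy).state_dict().keys()))   # ['module.weight', 'module.bias']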

def load_model(model, pretrained_path, load_to_cpu):
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    # Some checkpoints wrap the weights in a 'state_dict' entry; either way, strip the 'module.' prefix
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model

if __name__ == '__main__':
    model_path = './*.pth'   # path to the trained .pth checkpoint (placeholder)
    lib_path = './*.pt'      # output path for the traced TorchScript model (placeholder)
    torch.set_grad_enabled(False)
    cfg = cfg_mnet
    # net and model
    net = mymodel(cfg=cfg, phase='test')
    net = load_model(net, model_path, True)
    net.eval()
    x = torch.rand(1, 3, 640, 640)
    output1 = net(torch.ones(1, 3, 640, 640))
    model = torch.jit.trace(net, x)
    output2 = model(torch.ones(1, 3, 640, 640))
    # Compare the eager model's output with the traced model's output
    print('output1', output1)
    print('output2', output2)
    torch.jit.save(model, lib_path)
    print("->> Model conversion succeeded!")