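# 核心:模型序列字典中键名的匹配 (Core idea: match the key names in the models' state dictionaries)
# 1 读取mxnet模型 (Step 1: load the MXNet model)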
import mxnet as mx
import torch
from resnet50_gcn import resnet50
from torch.nn import BatchNorm2d
from torch.nn import Conv2d
from torch.nn import Linear
import pandas as pd
import numpy as np
def get_model(model_path, epoch):
    """Load a saved MXNet checkpoint.

    Args:
        model_path: Checkpoint prefix, forwarded unchanged to
            ``mx.model.load_checkpoint``.
        epoch: Epoch number, forwarded unchanged to
            ``mx.model.load_checkpoint``.

    Returns:
        The ``(sym, arg_params, aux_params)`` triple produced by
        ``mx.model.load_checkpoint``.
    """
    checkpoint = mx.model.load_checkpoint(model_path, epoch)
    # Unpack explicitly so a checkpoint that is not a 3-item result fails here.
    symbol, arg_weights, aux_weights = checkpoint
    return symbol, arg_weights, aux_weights
# Load epoch 0 of the MXNet checkpoint saved under the prefix
# 'resnet50_w-glore_0-3_'; the three results are kept as module-level names.
sym, arg_params, aux_params = get_model('resnet50_w-glore_0-3_', 0)
# 2 读取pytorch模型 (Step 2: load the PyTorch model)
# Instantiate the ResNet-50 from the local resnet50_gcn module; pretrained=False
# is passed, presumably so that no pretrained weights are loaded — TODO confirm.
pytorch_model = resnet50(pretrained=False)
# NOTE(review): neither init_model nor new_arg_params is defined or imported in
# the visible part of the file before this line — confirm both exist by the
# time this statement runs, otherwise it raises NameError.
pytorch_model = init_model(pytorch_model, new_arg_params)
# 3 键名匹配 (Step 3: match the key names)
def dict_change_keys(input_dict, change_map, new_dict):
old_dict_key = []
counts_num = 0
new_dict_key = pd.read_csv(change_map, header=None, sep=' ')
for key in