Q1: 请问想加载PyTorch预训练好的模型用于MindSpore模型finetune有什么方法?
A1: 需要把PyTorch和MindSpore的参数进行一一对应,因为网络定义的灵活性,所以没办法提供统一的转化脚本。
一般情况下,CheckPoint文件中保存的就是参数名和参数值,调用相应框架的读取接口后,获取到参数名和数值后,按照MindSpore格式,构建出对象,就可以直接调用MindSpore接口保存成MindSpore格式的CheckPoint文件了。
其中主要的工作量为对比不同框架间的parameter名称,做到两个框架的网络中所有parameter name一一对应(可以使用一个map进行映射),下面代码的逻辑转化parameter格式,不包括对应parameter name。
import torchimport mindspore as ms
def pytorch2mindspore(default_file = 'torch_resnet.pth'):
# read pth file
par_dict = torch.load(default_file)['state_dict']
params_list = []
for name in par_dict:
param_dict = {}
parameter = par_dict[name]
param_dict['name'] = name
param_dict['data'] = ms.Tensor(parameter.numpy())
params_list.append(param_dict)
ms.save_checkpoint(params_list, 'ms_resnet.ckpt')
Q2