import torch

# Instantiate our own model.
model1 = VGG('vgg16')
# Wrap the model in DataParallel. The wrapped network is held as ``.module``,
# so the wrapper's state-dict keys carry a "module." prefix.
model1 = torch.nn.DataParallel(model1)

# Load the checkpoint once and reuse it below.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model1_weights = torch.load('vgg16_eig_pretrained.pt')
# Remove the ".basic_model" segment from every key before loading.
# NOTE(review): presumably the checkpoint was saved from a wrapper that held the
# VGG under a ``basic_model`` attribute; confirm the rewritten keys match
# model1.state_dict() (load_state_dict is strict by default and raises on mismatch).
model1.load_state_dict(
    {k.replace(".basic_model", ""): v for k, v in model1_weights.items()}
)