这个问题是我在做模型剪枝的时候遇到的
- 先载入VGG16模型,更改全连接层,使输出单元个数与自己待处理的数据类别个数一致。
- 冻住卷积层,只训练全连接层。学习率设为0.0001,momentum=0.9。
- 当设定的epoch数目满足的时候,把模型保存,使用命令:
torch.save(model, "./cifar-10/5epochs_cifar10_vgg_model")
4 . 然后在ipython使用以下命令加载模型:
model = torch.load("/home/smiles/tsq/PyTorch/pytorch-pruning/cifar-10/5epochs_cifar10_vgg_model").cuda()
这个时候就出现错误:
<module '__main__' (built-in)> is a built-in class