我们通常会用到迁移学习,即在一个比较通用的pretext-task上做预训练,随后针对不同的downstream task进行微调。而在微调的时候,网络结构的最后几层通常是要做出改变的。举个例子,假设pretext-task是在imagenet上面做图像分类,而下游任务是做语义分割,那么在微调的时候需要将分类网络的最后几层全连接层去掉,改造成FCN的网络结构。此时就需要我们把前面层的权重加载进去。
如果改了模型结构以后,再简单粗暴的使用torch.load_state_dict(torch.load(‘xxx.pth’))那么肯定就会报错。所以具体怎么办呢,且耐心往下看。
首先我们定义一个简单的图像分类模型:
class model1(nn.Module):
def __init__(self, img_size):
super(model, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
self.conv2 = nn.Conv2d(16, 64, 3, 1, 1)
self.fc1 = nn.Linear(self.num_feature_pixel(img_size), 1024)
self.fc2 = nn.Linear(1024, 10)
def