有时候任务需要,想从一个训练好的网络里提取部分网络和参数做为自己的网络,本文将教你如何用pytorch实现。
首先看一下训练好的网络结构:
这是一个seq2seq网络,包含encoder和decoder两部分,每一部分都包含一个embedding层、一个LSTM层和一个Dropout层,decoder网络还有一个Linear层。
然后看一下新的网络结构:
同样是一个seq2seq的结构,区别是decoder网络里面只保留了一个Linear层,其他层都删掉了。那么怎么将训练好的网络参数复制过来呢?
1、保存预训练好的网络参数
torch.save(model.state_dict(), 'net-1.pth'