这里先只是搬运过来,https://github.com/zergtant/pytorch-handbook/blob/master/chapter4/4.1-fine-tuning.ipynb,还没测试,但觉得以后会用到。
根据提供的方法,需要经过如下步骤:
1)定义hook函数
in_list= [] #存放待输出中间层内容
def hook(module, input, output): #需要三元组输入
for i in range(input[0].size(0)):
in_list.append(input[0][i].cpu().numpy())
2)待输出中间层注册hook,比如原始例子注册了avgpool层
model.avgpool.register_forward_hook(hook)
3)前向跑起来并保存
with torch.no_grad():
for batch_idx, data in enumerate(dataloader):
x,y= data
y_ = model(x)
features=np.array(in_list)
np.save("features",features)