查看信息
import torch
weights = torch.load('./results/0.pkl')
# 打印出所有的键
for key in weights:
print(key)
# 查看特定参数的形状
print(weights['conv1.weight'].size())
自定义保存额外信息
1. 自定义保存额外信息
在训练时,可以选择以某种方式手动保存这些元数据。例如,在PyTorch中,除了保存模型的state_dict,你还可以保存一个包含训练元数据的字典,如下所示:
# 假设 model 是你的模型实例,optimizer 是你的优化器实例
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'learning_rate': learning_rate,
# 可以添加更多需要的信息
}, 'model_with_metadata.pth')
加载这个文件时,你可以读取这些额外的信息:
checkpoint = torch.load('model_with_metadata.pth')
epoch = checkpoint['epoch']
learning_rate = checkpoint['learning_rate']
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
2. 查看优化器状态
在某些情况下,如果你保存了优化器的状态(optimizer.state_dict()),你可以从中获取到某些训练参数,比如学习率。但是,这仍然不会直接告诉你总的训练epochs数,只能提供在保存时刻的优化器状态,如当前学习率:
optimizer_state = checkpoint['optimizer_state_dict']
current_lr = optimizer_state['param_groups'][0]['lr']
本文介绍了如何在PyTorch中自定义保存模型状态、优化器状态以及额外训练元数据,包括epoch数和学习率,并演示了如何加载这些信息进行恢复。
12万+

被折叠的 条评论
为什么被折叠?



