.pt
文件是模型权重,主要包含训练过程中保留的键值对信息
以自己训练的yolov5s
目标检测权重test.pt
为例:
1.模型文件加载
pt = r'weights/test.pt'
state = torch.load(pt)
2.模型键值对查询
for k in state.keys():
print(k)
epoch: 113(对后续推理无意义可删除)
best_fitness: [0.85533715] (对后续推理无意义可删除)
training_results: 0/599 2.47G 0.05426… (对后续推理无意义可删除)
model: 模型权重
ema:模型权重 (重复可删除)
updates: 23620 (对后续推理无意义可删除)
optimizer:优化器的一些参数 (只用作推理可删除)
wandb_id:可视化序列值 (可删除)
3.删除无效键值对(del)
del state['optimizer']['state']
del state['training_results']
del state['ema']
del state['updates']
4.查询键值类型(type)
print(type(state['model'])) <class 'models.yolo.Model'>
print(type(state['updates'])) <class 'int'>
print(type(['wandb_id'])) <class 'list'>
print(type(state['optimizer'])) <class 'dict'>
5.模型文件保存
state:加载的模型权重
filename:存储的文件名
torch.save(state, filename)
去掉无用的信息后,pt模型体积会减小很多:
6.实例代码
以yolov5模型为例:
import torch
pt = r'weights/best.pt'
# 加载
state = torch.load(pt)
# 键值对
for k in state.keys():
print(k)
del state['epoch']
del state['best_fitness']
del state['training_results']
del state['ema']
del state['updates']
del state['optimizer']
del state['wandb_id']
# 保存
torch.save(state, 'weights/best_cut.pt')
for k in state.keys():
print(k)