这篇文章,用于记录将TransReID的pytorch模型转换为onnx的学习过程,期间参考和学习了许多大佬编写的博客,在参考文章这一章节中都已列出,非常感谢。
1. 在pytorch下使用ONNX主要步骤
1.1. 环境准备
安装onnxruntime包
安装教程可参考:
onnx模型预测环境安装笔记
onnxruntime配置
CPU版本:
直接pip安装
pip install onnxruntime
GPU版本:
先查看自己CUDA版本然后在下面的链接去找对应的onnxruntime的版本
CUDA版本的查询,可参考这个
onnxruntime版本查询
查询到对应版本,直接pip安装即可,例如
pip install onnxruntime-gpu==1.13.1
安装onnxsim包
pip install onnx-simplifier
1.2. 搭建 PyTorch 模型(TransReID)
def get_net(model_path,opt_=False):
if opt_:
cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
#cfg.freeze()
train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
else:
cfg.merge_from_file("/home/TransReID-main/configs/OCC_Duke/vit_transreid_stride.yml")
train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
net = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num)
#state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model_state_dict=net.state_dict()
for key in list(state_dict.keys()):
if key[7:] in model_state_dict.keys():
model_state_dict[key[7:]]=state_dict[key]
net.load_state_dict(model_state_dict)
return net
1.3. pytorch模型转换为 ONNX 模型
这个提供了静态转换(静态转换支持静态输入)和动态转换(动态转换支持动态输入)两个函数,可根据需要选择。
def convert_onnx_dynamic(model,save_path,simp=False):
x = torch.randn(4, 3, 256,128)
input_name = 'input'
output_name = 'class'
torch.onnx.export(model,x,save_path,input_names = [input_name