pytorch的C++接口libtorch的模型导出和解析

本文讲述了如何在C++中使用libtorch库从yolov5s模型转换到yolov5n,遇到预期输出类型不匹配的问题,并通过分析和调整模型导出及代码适配解决了问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

项目需求

根据GitHub上的开源代码,使用libtorch库,解析yolov5s.torchscript.pt文件实现了C++的yolov5s的目标检测,然而我想使用更小的yolov5n来替换yolov5s。

错误尝试

首先我配置环境,使用yolov5-v7.0的export.py通过权重文件yolov5n.pt导出yolov5n.torchscript(其实我想导出的是yolov5n.torchscript.pt文件,不知到怎么调参,结果没成功),将开源代码中的yolov5s.torchscript.pt替换为yolov5n.torchscript。结果报错,报错信息摘要:Expected Tuple but got String,意思是说模型解析后期望生成的数据类型是Tuple(元组)但是得到的是String

然后我又用同样的方法把yolov5s.pt导出yolov5s.torchscript文件,并替换,不出意外抛出同样的异常。

接着我又从优快云上下载了一个别人上传的一个现成的yolov5s.torchscript.pt文件,替换后报错:Expected Tuple but got GenericList  从别人博客中了解到可能原因是模型中没有使用GPU。

错误原因分析与解决方案

问题原因:导出的模型和代码不匹配

解决方案:

1. 去确定一下自己模型的输出到底是什么,然后在 C++ 代码中用适当的数据类型接收处理即可。(我不会)

2. 根据代码的数据要求导出模型,以下是导出代码:

import torch
from models.experimental import attempt_load

# 加载模型
model = attempt_load('yolov5s.pt')
model.eval()  # 开启评估模式

# 创建示例输入
input_data = torch.randn(1, 3, 640, 384)  # 根据实际情况修改输入形状

# 转换模型
traced_script_module = torch.jit.trace(model, input_data, check_trace=False)

# 保存
traced_script_module.save("yolov5s.torchscript.pt")

根据自己代码的数据要求更改input_date 

进行模型转换的时侯报错

在traced_script_module = torch.jit.trace(model, input_data)这行代码后加入参数check_trace=False 问题解决,参考链接

最后,把自己导出的yolov5n.torchscript.pt替换掉源代码中的yolov5s.torchscript.pt代码成功的跑起来了。

由于对深度学习还没达到入门的程度,有哪里理解错误的地方还请不吝赐教,感谢🙇‍

最后,附一些记录:

1. 什么是.torchscript.pt文件

.torchscript.pt模型文件是使用PyTorch框架中的torchscript模块将训练好的模型转化为混合前端(JIT)模式的文件。此模式将模型转换为一种高性能的序列化表达格式,可以在PyTorchC++Java和其他支持Torchscript的平台上进行部署和推理。

要生成yolov5s.torchscript.pt模型文件,以下是大致步骤:

  1. 安装PyTorchYOLOv5:首先需要安装PyTorchyolov5库。
  2. 下载预训练模型:从YOLOv5的官方GitHub仓库中下载预训练的yolov5s模型文件。
  3. 加载模型:使用PyTorch加载下载的预训练模型文件。
  4. 导出模型:使用torch.jit.trace函数将加载的模型转换为torchscript模式。该函数会将模型前向计算的图表达为脚本代码。
  5. 保存模型:使用torch.jit.save函数将转换后的模型保存为yolov5s.torchscript.pt模型文件。

2. yolov5s.torchscript.pt文件和yolov5s.torchscript文件有什么区别

yolov5s.torchscript.pt和yolov5s.torchscript可能是同一种类型的文件,只是文件扩展名不同而已。他们二者都指的是TorchScript的模型文件。
TorchScript是一个Python的子集,可以在优化后的环境中运行,如C++中。这对于将模型部署到生产环境中尤为重要。
通过PyTorch的JIT(Just In Time)编译器,我们可以将PyTorch模型转化为TorchScript模型,这个过程称为跟踪(tracing)或脚本化(scripting)。

3. 得到一个yolov5s.torchscript.pt文件怎样使用C++代码和libtorch库进行推理

要在C中使用libtorch库进行推理,首先需要安装libtorch库。然后,你可以使用以下的C代码为例子进行推理:

#include <torch/script.h>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
    if (argc != 2) {
        std::cerr << "usage: example-app <path-to-exported-script-module>\n";
        return -1;
    }

    torch::jit::script::Module module;
    try {
        module = torch::jit::load(argv[1]);
    } catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }

    std::cout << "model loaded\n";

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 3, 224, 224}));

    at::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
}


在这个例子中,我们先加载模型,然后创建一个输入张量(在这里是一个1x3x224x224的张量),
然后将张量传递给模型进行前向传播,最后输出模型的输出。
注意,模型的输入和输出都需要使用torch::jit::IValue类型。

问:为什么创建1x3x224x224的张量,其他类型的不可以吗

创建1x3x224x224的张量主要是因为这是一种常见的输入数据形状,特别是对于处理图像的卷积神经网络。
在这里,1是批次大小(batch size),3是通道数(对应于RGB颜色通道),224x224是图像的高和宽。

对于其他形状的张量,是否可以作为输入,主要取决于模型的设计和期望的输入。
例如,如果你的模型是为了处理32x32的灰度图像设计的,那么你应该输入1x1x32x32的张量。
如果你尝试用一个形状不匹配的张量作为输入,你可能会得到一个错误。
所以,你应该根据你的模型的设计和需求来确定输入张量的形状。

4. 我从别人那里得到一个yolov5s.torchscript.pt文件,使用什么工具或者办法可以知道模型的设计和期望的输出

对于一个已经编译为TorchScript并保存为.pt文件的模型,直接查看模型设计和期望的输入输出可能会有一些困难,
因为一些高级别的信息可能在编译过程中被丢失。

但是,你可以通过以下方法获取一些基本信息:

加载模型并打印它。这将给出模型的结构信息,但可能不会包含所有细节。例如:

module = torch.jit.load('yolov5s.torchscript.pt')
print(module)
你可以通过执行模型并检查其输入和输出来推断模型的期望输入和输出。
例如,你可以尝试不同大小的输入,看模型是否报错,并检查返回的输出的形状以获取输出的信息。
对于YOLOv5模型,一般来说,它的输入是一个四维的张量,形状为(批量大小, 3, 宽度, 高度),
其中3对应于RGB图像的颜色通道。输出通常是一个预测张量,其中包含了物体的类别,以及物体在图像中的位置。
如果你想要详细的模型设计信息,最好的办法可能是联系模型的原始作者或者查阅模型的相关文档。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值