pt模型转tflite格式的血与泪

前言:用yolov8n训练完手势数据后,需要用到安卓移动端,所以需要转为tflite格式,不清楚各种格式的可以看我另一篇文章

.pt文件转成.tflite文件的原理:是将.pt文件转换成一种中间表示形式.onnx后再转换成.tflite格式,具体流程是:.pt => .onnx => .pb => .tflite

建议:能直接训练出来.tflite还是用tensorflow训练

成功转换流程如下

0、用conda创建虚拟环境

每个项目都需要有一套conda环境,本次采用python=3.10

conda create -n myenv3.10 python=3.10 -y
conda activate myenv3.10

1、安装 tensorflow

安装的是 2.8.0 版本

pip install tensorflow==2.8.0

2、接下来安装 tensorflow-addons

TensorFlow Addons(简称 tfa)是 TensorFlow 生态系统中的一个官方扩展库,提供了一系列补充功能和组件,帮助开发者更高效地构建和训练复杂的机器学习模型。它与 TensorFlow 的关系可以概括为:TensorFlow Addons 是 TensorFlow 的官方扩展,提供了额外的功能模块,但不属于核心TensorFlow 库。

安装版本为

pip install tensorflow-addons==0.18.0

其他版本可以去这里找: https://gitcode.com/tensorflow/addons/overview

3、根据tensorflow安装ONNX和ONNX-Tensorflow

安装版本为

pip install onnx-tf==1.10.0
pip install onnx==1.10.2

其他版本可以去这里找: https://github.com/onnx/onnx-tensorflow/tree/main

4、根据ONNX安装onnx Runtime和onnx opset

安装版本为,这里注意,onnxruntime官方对应的是1.10,但是会显示最低版本为1.12。

pip install onnxruntime==1.12
# pip install onnx-opset==15

其他版本可以去这里找: https://onnxruntime.ai/docs/reference/compatibility.html

5、如果numpy报错

安装1.22版本

pip install numpy==1.22

6、pt转onnx

import torch
model = torch.load("best.pt", map_location="cpu").eval()  # 加载模型并设为评估模式
dummy_input = torch.randn(1, 3, 640, 640)  # 示例输入,尺寸需匹配模型
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=13,  # YOLOv8推荐opset 13
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}  # 支持动态批次
)

在这里插入图片描述

7、onnx转pb

如果报错IR版本太高先转格式

import onnx

# 加载 ONNX 模型
model = onnx.load("D:\\xxx\\best.onnx")

# 查看当前 IR 版本
print(f"Current IR version: {model.ir_version}")

# 修改 IR 版本(例如降级到 8)
model.ir_version = 8

# 保存修改后的模型
onnx.save_model(model, "modified_model.onnx")

onnx_to_pb

import onnx
from onnx_tf.backend import prepare

TF_PATH = r"D:\\xxx\\tf_model"
ONNX_PATH = r"D:\\xxx\\modified_model.onnx"
onnx_model = onnx.load(ONNX_PATH)
tf_rep = prepare(onnx_model) 
tf_rep.export_graph(TF_PATH)

在这里插入图片描述

8、pb转tflite

import tensorflow as tf

saved_model_dir = r"D:\\xxx\\tf_model"
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, 
  tf.lite.OpsSet.SELECT_TF_OPS ]
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

在这里插入图片描述

相关报错如下

1 各种依赖不兼容报错,根据实际报错解决

2 安装onnx时少cmake

去官网下载安装 https://cmake.org/download/,安装时添加环境变量
在终端窗口验证

cmake --version  # 应输出 CMake 版本(如 4.0.2)

在这里插入图片描述

AssertionError: Could not find "cmake" executable!
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sky丶Mamba

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值