前言:用yolov8n训练完手势数据后,需要用到安卓移动端,所以需要转为tflite格式,不清楚各种格式的可以看我另一篇文章
.pt文件转成.tflite文件的原理:是将.pt文件转换成一种中间表示形式.onnx后再转换成.tflite格式,具体流程是:.pt => .onnx => .pb => .tflite
建议:能直接训练出来.tflite还是用tensorflow训练
成功转换流程如下
0、用conda创建虚拟环境
每个项目都需要有一套conda环境,本次采用python=3.10
conda create -n myenv3.10 python=3.10 -y
conda activate myenv3.10
1、安装 tensorflow
安装的是 2.8.0 版本
pip install tensorflow==2.8.0
2、接下来安装 tensorflow-addons
TensorFlow Addons(简称 tfa)是 TensorFlow 生态系统中的一个官方扩展库,提供了一系列补充功能和组件,帮助开发者更高效地构建和训练复杂的机器学习模型。它与 TensorFlow 的关系可以概括为:TensorFlow Addons 是 TensorFlow 的官方扩展,提供了额外的功能模块,但不属于核心TensorFlow 库。
安装版本为
pip install tensorflow-addons==0.18.0
其他版本可以去这里找: https://gitcode.com/tensorflow/addons/overview
3、根据tensorflow安装ONNX和ONNX-Tensorflow
安装版本为
pip install onnx-tf==1.10.0
pip install onnx==1.10.2
其他版本可以去这里找: https://github.com/onnx/onnx-tensorflow/tree/main
4、根据ONNX安装onnx Runtime和onnx opset
安装版本为,这里注意,onnxruntime官方对应的是1.10,但是会显示最低版本为1.12。
pip install onnxruntime==1.12
# pip install onnx-opset==15
其他版本可以去这里找: https://onnxruntime.ai/docs/reference/compatibility.html
5、如果numpy报错
安装1.22版本
pip install numpy==1.22
6、pt转onnx
import torch
model = torch.load("best.pt", map_location="cpu").eval() # 加载模型并设为评估模式
dummy_input = torch.randn(1, 3, 640, 640) # 示例输入,尺寸需匹配模型
torch.onnx.export(
model,
dummy_input,
"model.onnx",
opset_version=13, # YOLOv8推荐opset 13
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}} # 支持动态批次
)
7、onnx转pb
如果报错IR版本太高先转格式
import onnx
# 加载 ONNX 模型
model = onnx.load("D:\\xxx\\best.onnx")
# 查看当前 IR 版本
print(f"Current IR version: {model.ir_version}")
# 修改 IR 版本(例如降级到 8)
model.ir_version = 8
# 保存修改后的模型
onnx.save_model(model, "modified_model.onnx")
onnx_to_pb
import onnx
from onnx_tf.backend import prepare
TF_PATH = r"D:\\xxx\\tf_model"
ONNX_PATH = r"D:\\xxx\\modified_model.onnx"
onnx_model = onnx.load(ONNX_PATH)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(TF_PATH)
8、pb转tflite
import tensorflow as tf
saved_model_dir = r"D:\\xxx\\tf_model"
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS ]
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
相关报错如下
1 各种依赖不兼容报错,根据实际报错解决
2 安装onnx时少cmake
去官网下载安装 https://cmake.org/download/,安装时添加环境变量
在终端窗口验证
cmake --version # 应输出 CMake 版本(如 4.0.2)
AssertionError: Could not find "cmake" executable!
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error
× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.