鉴于 Windows 的兼容性问题,强烈建议使用 Google Colab,因为它提供 Linux 环境,预装 CUDA,兼容 ai_edge_litert。只需上传模型和脚本,安装依赖即可完成转换。
模型转换
打开 Google Colab(https://colab.research.google.com)
创建一个新笔记本,并上传训练好的模型权重(本文以 yolo11n.pt 为例;转换脚本中加载的文件名必须与实际上传的文件名一致)。
安装依赖(在 Colab 代码单元中执行;带 `>=` 的版本约束必须加引号,否则 `>` 会被 shell 解析为输出重定向,版本约束会被悄悄丢弃):
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu118
# Quote every ">=" requirement: unquoted, the shell treats ">" as output redirection
# and the version constraint is lost.
pip install "ultralytics>=8.3.15" "opencv-python>=4.6.0" tensorflow==2.18.0 tf_keras==2.18.0 onnx==1.17.0 "onnx2tf>=1.26.3" "sng4onnx>=1.0.1" "onnx_graphsurgeon>=0.3.26" "sympy>=1.13.3" "protobuf>=5.26.1" "onnxslim>=0.1.59"
运行转换脚本:
from ultralytics import YOLO

# Load the trained weights uploaded to Colab; the file name must match the upload.
model = YOLO("yolo11n.pt")

# Export to TFLite (float32). Ultralytics writes it to
# "<stem>_saved_model/<stem>_float32.tflite", e.g. "yolo11n_saved_model/yolo11n_float32.tflite".
# export() returns the path of the exported artifact, printed here so it is easy to download.
tflite_path = model.export(format="tflite", imgsz=320)
print(f"TFLite model: {tflite_path}")

# Export to NCNN; this produces a directory, e.g. "./yolo11n_ncnn_model".
ncnn_path = model.export(format="ncnn", imgsz=320)
print(f"NCNN model: {ncnn_path}")
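下载生成的模型文件:TFLite 模型位于 `<模型名>_saved_model/` 目录下(如 `yolo11n_saved_model/yolo11n_float32.tflite`),NCNN 模型是整个 `<模型名>_ncnn_model/` 目录。Android 端需要把 .tflite 文件放入工程的 `assets` 目录。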
TFLite 模型推理
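- 加载导出的 TFLite 模型(Python 端可直接用 ultralytics 的 `YOLO` 加载 .tflite 文件,也可使用 tflite_runtime / ai_edge_litert 的解释器;Android 端使用 `org.tensorflow.lite.Interpreter`)。
- 预处理输入图像(缩放到模型输入尺寸、归一化到 [0, 1]、按模型要求的通道顺序排列,Ultralytics 模型为 RGB)。
- 执行推理并获取输出张量。
- 对输出进行后处理以解析检测结果(边界框、类别、置信度),并用非极大值抑制(NMS)去除同一目标的重叠框。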
Python 端推理代码(依赖 ultralytics、opencv-python 和 pandas):
# TFLite model inference
def tflite_infer(model_path, img_path):
    """Run an exported TFLite YOLO model on one image, print and save the detections.

    Args:
        model_path: Path to the exported ``.tflite`` file.
        img_path: Path to the input image. The annotated image is written next
            to it as ``detected.jpg``.

    Returns:
        The results list returned by the model call (previously ``None``).

    Raises:
        OSError: if the annotated image cannot be written.
    """
    # Local imports keep the snippet self-contained when pasted on its own.
    from pathlib import Path

    import cv2
    from ultralytics import YOLO

    # YOLO() accepts the exported .tflite file directly.
    tflite_model = YOLO(model_path)
    results_tflite = tflite_model(img_path)

    # Print the detections as a table.
    output_result(results_tflite)

    # Draw boxes with class labels using Ultralytics' own plotting, then save.
    save_path = Path(img_path).with_name('detected.jpg')
    annotated = results_tflite[0].plot(labels=True)
    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(str(save_path), annotated):
        raise OSError(f"Failed to write result image: {save_path}")
    print(f'结果图已保存: {save_path}')
    return results_tflite
# Extract the detection results into a table.
def output_result(results):
    """Print the detections of ``results[0]`` as a DataFrame sorted by confidence.

    Args:
        results: Sequence of Ultralytics results; only the first element is used.
            Its ``boxes`` must expose ``conf``, ``xyxy`` and ``cls`` tensors.

    Returns:
        A pandas DataFrame with columns ``similar`` (confidence), ``rect``
        (``[x1, y1, x2, y2]``) and ``class_id``, sorted by ``similar`` in
        descending order. The frame is empty when nothing was detected.
    """
    # The snippet never imported pandas; import it here so the function works standalone.
    import pandas as pd

    boxes = results[0].boxes
    df = pd.DataFrame({
        "similar": boxes.conf.cpu().numpy(),  # confidence score
        "rect": [b.tolist() for b in boxes.xyxy.cpu().numpy()],  # [x1, y1, x2, y2]
        "class_id": boxes.cls.cpu().numpy().astype(int),
    })
    # Highest-confidence detections first.
    df_sorted = df.sort_values("similar", ascending=False)
    print(df_sorted)
    return df_sorted
Android 端推理代码(Java):
package com.magicianguo.mediaprojectiondemo.service;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.util.Log;
import androidx.annotation.NonNull;
import org.tensorflow.lite.Interpreter;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
/**
 * Runs a YOLO11 detection model exported to TFLite on Android bitmaps.
 *
 * <p>Input is built as a {@code [1, inputSize, inputSize, 3]} float array holding RGB values
 * scaled to [0, 1]. The single output is read as {@code [1, 4 + numClasses, numBoxes]}:
 * rows 0-3 are normalized (cx, cy, w, h) and the remaining rows are per-class scores.
 * Predictions at or below {@link #threshold} are discarded and the rest go through
 * class-aware non-maximum suppression.
 */
public class YoloV11TFLiteDetector {
    private static final String TAG = "YoloV11TFLiteDetector";

    // Null if the model failed to load or after close().
    private Interpreter tflite;
    // Class labels, indexed by the class id predicted by the model.
    private static final String[] names = {
            "Battle_Identification", "Can_Be_Purchased", "Chess_Pieces", "Current_Player",
            "Fetters", "Player", "Player_Character", "Prepare_For_War",
            "Preparing_Chess_Piece_Area", "Preparing_For_Chess_Pieces", "Rune", "Store"
    };
    // Minimum class score for a prediction to be kept.
    private final float threshold = 0.7f;
    // A box is suppressed when it overlaps a higher-scoring box of the same class by more
    // than this IoU. 0.7 mirrors the Ultralytics predict() default so results match Python.
    private final float iouThreshold = 0.7f;
    private final Context context;
    // Square input resolution the model was exported with.
    private final int inputSize = 320;

    /** One detected object: box corners, confidence and class id. */
    public static class Detection {
        public float x1, y1, x2, y2; // Bounding box corners
        public float conf;           // Confidence score
        public int classId;          // Index into the class-name table

        public Detection(float x1, float y1, float x2, float y2, float conf, int classId) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
            this.conf = conf;
            this.classId = classId;
        }

        @NonNull
        @Override
        public String toString() {
            // Guard the lookup so a model with more classes than names cannot crash logging.
            String className = (classId >= 0 && classId < names.length) ? names[classId] : "unknown";
            return String.format(Locale.US, "Detection: {classId=%d, className=%s conf=%.2f, rect=(%.2f, %.2f, %.2f, %.2f)}",
                    classId, className, conf, x1, y1, x2, y2);
        }
    }

    /**
     * Loads the TFLite model from the app's assets.
     *
     * @param context   context used to open the asset
     * @param modelPath asset path of the .tflite file
     */
    public YoloV11TFLiteDetector(Context context, String modelPath) {
        this.context = context;
        try {
            tflite = new Interpreter(loadModelFile(modelPath), new Interpreter.Options().setNumThreads(1));
            Log.i(TAG, "Model loaded successfully");
        } catch (IOException e) {
            // tflite stays null; detect() checks for that and returns no detections.
            Log.e(TAG, "Failed to load model", e);
        }
    }

    /**
     * Memory-maps a model file from assets.
     *
     * <p>The descriptor, stream and channel are closed afterwards; a MappedByteBuffer stays
     * valid after its channel is closed. NOTE: openFd requires the asset to be stored
     * uncompressed in the APK.
     */
    private MappedByteBuffer loadModelFile(String modelPath) throws IOException {
        try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
             FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
             FileChannel fileChannel = inputStream.getChannel()) {
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    }

    /**
     * Detects objects in a bitmap.
     *
     * @param bitmap image to analyse; may be any size, it is stretched to the model input
     * @return detections in the bitmap's pixel coordinates (clamped to its bounds), or an empty
     *         list if the bitmap is null or the model is not loaded
     */
    public List<Detection> detect(Bitmap bitmap) {
        if (bitmap == null) {
            Log.e(TAG, "Provided Bitmap is null");
            return new ArrayList<>();
        }
        if (tflite == null) {
            Log.e(TAG, "Interpreter is not initialized");
            return new ArrayList<>();
        }
        int width = bitmap.getWidth();   // width in pixels
        int height = bitmap.getHeight(); // height in pixels
        Log.i(TAG, "width:" + width + ",height:" + height);

        // Preprocess: plain resize (no letterbox), so normalized outputs map back by width/height.
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true);
        float[][][][] inputImage = bitmapToFloatArray(resizedBitmap);
        // createScaledBitmap may return the source itself; only recycle our own copy.
        if (resizedBitmap != bitmap) {
            resizedBitmap.recycle();
        }

        // Size the output buffer from the model instead of hard-coding [1][16][2100].
        int[] outputShape = tflite.getOutputTensor(0).shape();
        float[][][] output = new float[outputShape[0]][outputShape[1]][outputShape[2]];
        tflite.run(inputImage, output);

        List<Detection> detections = parseOutput(output[0]);
        // Scale normalized coordinates to the original bitmap and keep them inside it.
        for (Detection detection : detections) {
            detection.x1 = clamp(detection.x1 * width, 0f, width);
            detection.y1 = clamp(detection.y1 * height, 0f, height);
            detection.x2 = clamp(detection.x2 * width, 0f, width);
            detection.y2 = clamp(detection.y2 * height, 0f, height);
        }
        return detections;
    }

    /** Releases the native interpreter; detect() returns no results afterwards. */
    public void close() {
        if (tflite != null) {
            tflite.close();
            tflite = null;
        }
    }

    /** Converts an inputSize x inputSize bitmap into a [1][H][W][3] RGB float array in [0, 1]. */
    private float[][][][] bitmapToFloatArray(Bitmap bitmap) {
        // One bulk read instead of a getPixel() call per pixel.
        int[] pixels = new int[inputSize * inputSize];
        bitmap.getPixels(pixels, 0, inputSize, 0, 0, inputSize, inputSize);
        float[][][][] inputImage = new float[1][inputSize][inputSize][3];
        for (int y = 0; y < inputSize; y++) {
            for (int x = 0; x < inputSize; x++) {
                int pixel = pixels[y * inputSize + x]; // packed ARGB
                // Ultralytics models are trained on RGB input.
                inputImage[0][y][x][0] = ((pixel >> 16) & 0xFF) / 255.0f; // R
                inputImage[0][y][x][1] = ((pixel >> 8) & 0xFF) / 255.0f;  // G
                inputImage[0][y][x][2] = (pixel & 0xFF) / 255.0f;         // B
            }
        }
        return inputImage;
    }

    /**
     * Decodes one [4 + numClasses][numBoxes] output slice into detections.
     *
     * <p>Boxes are returned in normalized corner coordinates after score filtering and NMS.
     */
    private List<Detection> parseOutput(float[][] output) {
        List<Detection> candidates = new ArrayList<>();
        int attributesPerDetection = output.length;   // 4 box values + one score per class
        int numClasses = attributesPerDetection - 4;
        if (numClasses <= 0) {
            Log.e(TAG, "Unexpected output layout: " + attributesPerDetection + " attributes");
            return candidates;
        }
        int numDetections = output[0].length;          // candidate boxes (2100 at 320x320)

        float[] classProbs = new float[numClasses];    // reused for every box
        for (int i = 0; i < numDetections; i++) {
            // 1. Per-class scores live in rows 4..(4 + numClasses - 1).
            for (int j = 0; j < numClasses; j++) {
                classProbs[j] = output[4 + j][i];
            }
            // 2. Best class and its score.
            int classId = argmax(classProbs);
            float conf = classProbs[classId];
            // 3. Drop low-confidence predictions.
            if (conf <= threshold) {
                continue;
            }
            // 4. Convert center/size to corner coordinates.
            float x = output[0][i];
            float y = output[1][i];
            float w = output[2][i];
            float h = output[3][i];
            candidates.add(new Detection(x - w / 2, y - h / 2, x + w / 2, y + h / 2, conf, classId));
        }

        List<Detection> detections = nonMaxSuppression(candidates);
        for (Detection d : detections) {
            Log.i(TAG, String.format(Locale.US, "%.2f %.2f %.2f %.2f %.2f %d",
                    d.x1, d.y1, d.x2, d.y2, d.conf, d.classId));
        }
        return detections;
    }

    /**
     * Greedy class-aware NMS: keeps the highest-scoring box and drops same-class boxes that
     * overlap it by more than {@link #iouThreshold}. Sorts {@code candidates} in place.
     */
    private List<Detection> nonMaxSuppression(List<Detection> candidates) {
        Collections.sort(candidates, (a, b) -> Float.compare(b.conf, a.conf));
        List<Detection> kept = new ArrayList<>();
        boolean[] suppressed = new boolean[candidates.size()];
        for (int i = 0; i < candidates.size(); i++) {
            if (suppressed[i]) {
                continue;
            }
            Detection best = candidates.get(i);
            kept.add(best);
            for (int j = i + 1; j < candidates.size(); j++) {
                if (suppressed[j]) {
                    continue;
                }
                Detection other = candidates.get(j);
                if (other.classId == best.classId && iou(best, other) > iouThreshold) {
                    suppressed[j] = true;
                }
            }
        }
        return kept;
    }

    /** Intersection-over-union of two corner-format boxes; 0 when the union is empty. */
    private static float iou(Detection a, Detection b) {
        float interW = Math.max(0f, Math.min(a.x2, b.x2) - Math.max(a.x1, b.x1));
        float interH = Math.max(0f, Math.min(a.y2, b.y2) - Math.max(a.y1, b.y1));
        float inter = interW * interH;
        float areaA = Math.max(0f, a.x2 - a.x1) * Math.max(0f, a.y2 - a.y1);
        float areaB = Math.max(0f, b.x2 - b.x1) * Math.max(0f, b.y2 - b.y1);
        float union = areaA + areaB - inter;
        return union <= 0f ? 0f : inter / union;
    }

    /** Limits value to the range [lo, hi]. */
    private static float clamp(float value, float lo, float hi) {
        return Math.max(lo, Math.min(hi, value));
    }

    /** Index of the largest element; ties resolve to the lowest index. */
    private int argmax(float[] array) {
        int maxIdx = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[maxIdx]) {
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}
修改模块(app)级 build.gradle,添加以下依赖:
dependencies {
    // TensorFlow Lite runtime. The model is converted with TensorFlow 2.18.0; a much older
    // runtime such as 2.9.0 may refuse to load it ("Didn't find op for builtin opcode ... version ...").
    // 2.16.1 is, at the time of writing, the last release under the org.tensorflow group;
    // newer runtimes are published as LiteRT (com.google.ai.edge.litert).
    implementation 'org.tensorflow:tensorflow-lite:2.16.1'
    // GPU delegate: optional, the Java detector in this guide does not use it.
    // Recent versions also need the -api artifact.
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.16.1'
    implementation 'org.tensorflow:tensorflow-lite-gpu-api:2.16.1'
    implementation 'org.tensorflow:tensorflow-lite-support:0.4.4'
    // Select TF ops (Flex) runtime: only needed if the converted model contains TF ops;
    // it adds considerably to APK size.
    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.16.1'
}