YOLOv11.pt 模型转换为 TFLite 和 NCNN 模型

鉴于 Windows 的兼容性问题,强烈建议使用 Google Colab,因为它提供 Linux 环境,预装 CUDA,兼容 ai_edge_litert。只需上传模型和脚本,安装依赖即可完成转换。

模型转换

打开 Google Colab(https://colab.research.google.com)

创建一个新笔记本并上传训练好的 yolo11n.pt 模型。

安装依赖:

pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu118
pip install ultralytics>=8.3.15 opencv-python>=4.6.0 tensorflow==2.18.0 tf_keras==2.18.0 onnx==1.17.0 onnx2tf>=1.26.3 sng4onnx>=1.0.1 onnx_graphsurgeon>=0.3.26 sympy>=1.13.3 protobuf>=5.26.1 onnxslim>=0.1.59

运行转换脚本:

from ultralytics import YOLO

# 加载模型
model = YOLO("best1.pt")

# 导出为 TFLite
model.export(format="tflite", imgsz=320)  # 创建 'yolo11n_float32.tflite'

# 导出为 NCNN
model.export(format="ncnn", imgsz=320)  # 创建 './yolo11n_ncnn_model'

下载生成的 .tflite 和 NCNN 模型文件。

TFLite 模型推理

  • 使用 tflite_runtime.interpreter 加载导出的 TFLite 模型。

  • 预处理输入图像(调整大小、归一化、格式转换)。

  • 执行推理并获取输出张量。

  • 输出需要进一步后处理以解析检测结果(如边界框、类别、置信度)。

python端推理代码:
# TFLite 模型推理
def tflite_infer(model_path, img_path):
    # 加载导出的 TFLite 模型进行推理
    tflite_model = YOLO(model_path)
    results_tflite = tflite_model(img_path)
    # 提取结果
    output_result(results_tflite)
    # 检测结果可视化并保存
    save_path = Path(img_path).with_name('detected.jpg')
    # 使用 ultralytics 自带的画框(已带标签)
    annotated = results_tflite[0].plot(labels=True)   # labels=True 会显示类别名
    cv2.imwrite(str(save_path), annotated)
    print(f'结果图已保存: {save_path}')

# 提取结果
def output_result(results):
    boxes = results[0].boxes

    # 构造 DataFrame
    df = pd.DataFrame({
        "similar": boxes.conf.cpu().numpy(),          # 置信度
        "rect": [b.tolist() for b in boxes.xyxy.cpu().numpy()],  # [x1,y1,x2,y2]
        "class_id": boxes.cls.cpu().numpy().astype(int)
    })

    # 按相似度降序排序
    df_sorted = df.sort_values("similar", ascending=False)
    print(df_sorted)
android端推理代码:
package com.magicianguo.mediaprojectiondemo.service;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.util.Log;
import androidx.annotation.NonNull;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class YoloV11TFLiteDetector {
    private Interpreter tflite;
    private static final String[] names = {
            "Battle_Identification", "Can_Be_Purchased", "Chess_Pieces", "Current_Player",
            "Fetters", "Player", "Player_Character", "Prepare_For_War",
            "Preparing_Chess_Piece_Area", "Preparing_For_Chess_Pieces", "Rune", "Store"
    };
    private final float threshold = 0.7f;
    private final Context context;
    private final int inputSize = 320;

    // Detection result class
    public static class Detection {
        public float x1, y1, x2, y2; // Bounding box coordinates
        public float conf;           // Confidence score
        public int classId;          // Class ID

        public Detection(float x1, float y1, float x2, float y2, float conf, int classId) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
            this.conf = conf;
            this.classId = classId;
        }

        @NonNull
        @Override
        public String toString() {
            return String.format(Locale.US, "Detection: {classId=%d, className=%s conf=%.2f, rect=(%.2f, %.2f, %.2f, %.2f)}",
                    classId, YoloV11TFLiteDetector.names[classId], conf, x1, y1, x2, y2);
        }
    }

    public YoloV11TFLiteDetector(Context context, String modelPath) {
        this.context = context;
        try {
            tflite = new Interpreter(loadModelFile(modelPath), new Interpreter.Options().setNumThreads(1));
            Log.i("YoloV11TFLiteDetector", "Model loaded successfully");
        } catch (IOException e) {
            Log.e("YoloV11TFLiteDetector", "Failed to load model", e);
        }
    }

    // Load TFLite model from assets
    private MappedByteBuffer loadModelFile(String modelPath) throws IOException {
        AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelPath);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }

    // Detect objects in the image
    public List<Detection> detect(Bitmap bitmap) {
        if (bitmap == null) {
            Log.e("YoloV11TFLiteDetector", "Provided Bitmap is null");
            return new ArrayList<>(); // Return empty list if bitmap is null
        }
        int width = bitmap.getWidth();   // 宽度(像素)
        int height = bitmap.getHeight(); // 高度(像素)
        Log.i("YoloV11TFLiteDetector", "width:"+width+",height:"+height);
        // Load and preprocess image
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true);
        float[][][][] inputImage = bitmapToFloatArray(resizedBitmap);

        // Run inference
        float[][][] output = new float[1][16][2100];
        tflite.run(inputImage, output);

        // Parse output
        List<Detection> detections = parseOutput(output[0]);
        for (Detection detection : detections) {
            detection.x1 *= width;
            detection.y1 *= height;
            detection.x2 *= width;
            detection.y2 *= height;
        }
        return detections;
    }

    // Convert Bitmap to float array for model input
    private float[][][][] bitmapToFloatArray(Bitmap bitmap) {
        float[][][][] inputImage = new float[1][inputSize][inputSize][3];
        for (int y = 0; y < inputSize; y++) {
            for (int x = 0; x < inputSize; x++) {
                int pixel = bitmap.getPixel(x, y);
                // 若模型需要BGR,交换R和B通道(根据训练数据格式调整)
                inputImage[0][y][x][0] = ((pixel & 0xFF)) / 255.0f; // B
                inputImage[0][y][x][1] = ((pixel >> 8) & 0xFF) / 255.0f;  // G
                inputImage[0][y][x][2] = ((pixel >> 16) & 0xFF) / 255.0f; // R
            }
        }
        return inputImage;
    }

    // Parse model output
    private List<Detection> parseOutput(float[][] output) {
        List<Detection> detections = new ArrayList<>();
        int numDetections = output[0].length; // 2100个检测框
        int attributesPerDetection = output.length; // 16个属性(4坐标+12类别)

        for (int i = 0; i < numDetections; i++) {
            // 1. 解析归一化坐标(x1, y1, x2, y2)
            float x = output[0][i];
            float y = output[1][i];
            float w = output[2][i];
            float h = output[3][i];

            float x1 = x-w/2;
            float x2 = x+w/2;
            float y1 = y-h/2;
            float y2 = y+h/2;

            // 2. 解析12个类别的置信度(索引4-15)
            float[] classProbs = new float[12]; // 匹配元数据的12个类别
            for (int j = 0; j < 12; j++) {
                int index = 4 + j;
                if (index < attributesPerDetection) {
                    classProbs[j] = output[index][i];
                }
            }

            // 3. 提取最大置信度和对应类别
            int classId = argmax(classProbs);
            float conf = classProbs[classId];

            // 4. 过滤低置信度结果
            if (conf > threshold) {
                Log.i("YoloV11TFLiteDetector", String.format("%.2f %.2f %.2f %.2f %.2f %d", x1, y1, x2, y2, conf, classId));
                detections.add(new Detection(x1, y1, x2, y2, conf, classId));
            }
        }
        return detections;
    }

    // 确保argmax方法正确处理12个类别
    private int argmax(float[] array) {
        int maxIdx = 0;
        for (int i = 1; i < array.length; i++) { // array长度为12
            if (array[i] > array[maxIdx]) {
                maxIdx = i;
            }
        }
        return maxIdx;
    }
}

修改 build.gradle 添加依赖

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.3.1'
    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'
}

<think>好的,我现在需要帮助用户解决如何在VMware中安装合适的USB设备驱动的问题。首先,我需要理解用户的问题到底是什么。用户提到的是“VMware USB device在哪下载合适的驱动”,可能是指当他们在VMware虚拟机中使用USB设备时遇到了驱动问题,需要找到正确的驱动程序。 接下来,我要回忆一下VMware产品中关于USB设备支持的相关知识。VMware Workstation和Fusion等产品通常会自动处理USB设备的连接,但在某些情况下,可能需要安装额外的驱动,尤其是在客户机操作系统中。例如,如果客户机是Windows,可能需要安装VMware Tools来增强USB支持;如果是Linux,可能需要特定的驱动包。 然后,我需要考虑用户可能的具体情况。用户可能已经将USB设备连接到主机,但在虚拟机中无法识别,或者出现驱动错误。这时候需要确定是主机驱动的问题还是客户机驱动的问题。通常,主机不需要额外驱动,因为VMware会通过自身的服务处理USB设备的连接,而客户机可能需要安装VMware Tools来获得更好的兼容性。 接着,我需要思考正确的步骤来引导用户解决问题。首先,应该检查VMware Tools是否已经安装,这是解决很多设备兼容性问题的第一步。如果已经安装,可能需要更新到最新版本。其次,对于客户机操作系统,比如Windows,可能需要检查设备管理器中的USB控制器是否有感叹号,是否需要手动安装驱动。此外,还需要确认主机的USB控制器设置是否正确,比如在虚拟机设置中选择USB兼容性版本(如USB 3.0或2.0)。 如果用户的问题是关于主机系统的驱动,比如VMwareUSB控制器驱动在主机上出现问题,可能需要到VMware官网下载并重新安装VMware Workstation或Player,这样会自动安装所需的驱动。另外,某些情况下,主机的USB驱动可能需要更新,比如芯片组驱动或USB 3.0驱动,这需要到主机厂商的网站下载。 还需要考虑用户可能的误解。例如,用户可能认为需要单独下载USB驱动,但实际上,大多数情况下,VMware Tools已经包含了必要的驱动。如果用户已经安装了VMware Tools但仍然有问题,可能需要进一步的故障排除,比如检查USB设备的兼容性,尝试不同的USB端口,或者重启VMware的相关服务。 最后,整理出清晰的步骤,确保用户能够按照流程操作,同时提醒常见的注意事项,比如以管理员身份运行安装程序,确保虚拟机配置正确等。可能还需要提供VMware官网的链接,方便用户下载工具或更新版本。</think>在VMware虚拟机中使用USB设备时,通常不需要单独下载"VMware USB设备驱动",但需要按以下流程操作: ### 一、基础准备 1. **主机端** - 确保主机已安装最新版VMware Tools(VMware Workstation/Fusion自带) - 无需单独下载驱动,系统会自动加载`vmware-usbarbitrator`服务 2. **虚拟机配置** - 打开虚拟机设置 → USB控制器 → 勾选`USB兼容性`(建议选3.0或自动) ### 二、客户机驱动安装 1. **Windows客户机** - 安装VMware Tools(虚拟机菜单 → 管理 → 安装VMware Tools) - 插入USB设备后,会自动加载`vmware USBCCID`驱动程序 2. **Linux客户机** - 安装open-vm-tools: ```bash sudo apt-get install open-vm-tools-desktop ``` - 加载USB驱动模块: ```bash sudo modprobe uhci_hcd ehci_hcd ohci_hcd xhci_hcd ``` ### 三、常见问题排查 1. **设备无法识别** - 检查主机设备管理器 → 确保无`VMware USB Device`黄色感叹号 - 重启服务:`services.msc` → 重启`VMware USB Arbitration Service` 2. **特殊设备支持** - 加密狗/工控设备需在虚拟机设置 → USB控制器 → 开启`显示所有USB输入设备` - 对于USB转串口设备,建议在客户机安装对应芯片驱动(如FTDI、CH340) ### 四、驱动下载渠道(仅限特殊情况) 如需获取底层驱动文件,可通过: 1. VMware官网支持页面: ``` https://customerconnect.vmware.com/downloads ``` 2. 选择对应产品 → 驱动包通常包含在`VMware Tools Bundle`中 > **注意事项**: > - 虚拟机运行时才能看到USB设备连接选项 > - 苹果M系列芯片需使用ARM版Windows/Linux系统 > - 安卓设备需开启开发者模式+USB调试
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值