pytorch 移动端框架 thnets 附c示例代码

本文分享了一个移动端图像识别项目的实践经验,对比了mxnet、thnets、caffe及tensorflow等框架的使用心得,重点介绍了thnets在iOS下的应用及性能优化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前年年前做一个手机移动端图像识别项目的时候,

先后尝试了mxnet,thnets,caffe,tensorflow.

当时的情况是,mxnet内存管理奇差,内存经常由于模型运算分配不足,app挂掉。

后来调研了下caffe发现也很不友好。

最后发现thnets相对比较轻巧,

经过算法调优之后,性能还不错,

特别是在ios下启用了Accelerate加速库。

后来tensorflow快速发展,就切到tensorflow上了。

最近看了下thnets,作者 mvitez 看来不怎么上心了。

至今为止,改进的不多。

thnets的移动端样例代码,可以参考:

https://github.com/e-lab/apps-iOs

https://github.com/e-lab/apps-android

有一段时间nnpack加速库起来了,就想着把thnets给patch一下nnpack.

但是由于项目太赶,没那个时间去做。

后来也因为切换到tensorflow上了。

thnets就被雪藏了。

向作者提交了两个建议,1,改用stb_image加载图片  2, 支持windows平台

这个两个工作,我都做了。

作者合了1。

2 我关闭了。几天前去看历史记录,作者当时问我关闭的原因,我没回。

真正的原因是。。。thnets被我遗忘了,而windows 版本的存在意义并不大。

今天稍微花了点时间,在windows写个thnets的demo样例,给有需要的网友~

项目地址:

https://github.com/cpuimage/thnets

代码示例见:demo.c

#include <string.h> 
#include <stdlib.h> 
#include <stdio.h> 
#include "thnets.h"
// http://cpuimage.cnblogs.com/
THNETWORK * net;

char * labels[] = { "lamp"," bottle"," watch"," pen"," grass"," shoe"," wall"," chair"," mug"," fork"," table"," book"," tablet"," bookcase"," pencil"," door"," face"," ceiling"," sofa"," bicycle"," aluminum - can"," window"," road"," stairs"," floor"," painting"," toy"," remote"," computer"," plant"," television"," dog"," laptop"," microwave"," cat"," tree"," knife"," car"," motorcycle"," person"," cup"," sidewalk"," telephone"," spoon"," hand"," sofabed" };

int main(int argc, char ** argv) {
    img_t image = { 0 };
    //test.jpg 
    char * pic_file = argv[1];
    //model
    char * model_path = argv[2];

    int dropclassifier = 0;
    loadimage(pic_file, &image);
    THInit();
    printf("init_thnets.");
    net = THLoadNetwork(model_path);
    if (net) {
        THUseSpatialConvolutionMM(net, 2);
        if (dropclassifier == 1) {
            if (net->net->modules[net->net->nelem - 1].type == MT_SoftMax)
                net->net->nelem--;
            if (net->net->modules[net->net->nelem - 1].type == MT_Linear)
                net->net->nelem--;
            if (net->net->modules[net->net->nelem - 1].type == MT_View)
                net->net->nelem--;
        }
    }
    else {
        printf("Shiiiiit went down.");
        return -1;
    }
    float * percentages = 0;
    int outwidth, outheight;

    if (!net) {
        return 0;
    }
    int i = 0;
    int nbatch = 1;
    unsigned char *bitmaps[256];
    for (i = 0; i < nbatch; i++)
        bitmaps[i] = image.bitmap;
    int size = THProcessImages(net, bitmaps, nbatch, image.width, image.height, image.cp * image.width, &percentages, &outwidth, &outheight, 0);
    for (i = 0; i < nbatch; i++)
        if (bitmaps[i])
            free(bitmaps[i]);
    if (percentages)
    {
        float max[3] = { 0 };
        int maxPos[3] = { 0 };
        for (int j = 0; j < size; j++) {
            if (percentages[j] > max[0]) {
                maxPos[0] = j;
                max[0] = percentages[j];
            }
            else if (percentages[j] > max[1]) {
                maxPos[1] = j;
                max[1] = percentages[j];
            }
            else if (percentages[j] > max[2]) {
                maxPos[2] = j;
                max[2] = percentages[j];
            }
        }
        for (int index = 0; index < 3; index += 1) {
            const float predictionValue = max[index];
            if (predictionValue > 0.05f) {
                printf(" %f %s  \n", predictionValue, labels[maxPos[index] % size]);
            }
        }
        free(percentages);
    }
    getchar();
}

 

采用vs编译的话,需要下载openblas:

https://jaist.dl.sourceforge.net/project/openblas/v0.2.19/

lib库链接文件:libopenblas.dll.a 即可编译。

附带的模型 来自e-lab的项目。

示例输出得分最高并且高于0.05的三个结果。

示例图片:test.jpg

示例模型:  model

对应的标签,见代码或模型文件夹下的categories.txt

以上,待有精力再对thnets进行性能调优。

对于前向传播而言,最好的代码学习资料莫过于:darknet

https://github.com/pjreddie/darknet

darknet 代码写得十分耐看,逻辑清晰。

darknet+nnpack:

https://github.com/digitalbrain79/darknet-nnpack

 

若有其他相关问题或者需求也可以邮件联系俺探讨。

邮箱地址是: 
gaozhihan@vip.qq.com

 

转载于:https://www.cnblogs.com/cpuimage/p/8366792.html

### PyTorch 移动端部署教程 在移动端部署 PyTorch 模型需要经过模型转换、优化和集成到移动应用的过程。以下是关于如何实现这一目标的详细说明: #### 1. 模型转换为 TorchScript PyTorch 提供了两种将模型转换为 TorchScript 的方式:**追踪(Tracing)** 和 **脚本化(Scripting)**。这两种方法都可以生成可以在移动端运行的模型。 ```python import torch from torchvision import models # 加载预训练模型 model = models.resnet18(pretrained=True) model.eval() # 创建示例输入 example_input = torch.rand(1, 3, 224, 224) # 使用追踪方法生成 TorchScript 模型 traced_model = torch.jit.trace(model, example_input) traced_model.save("resnet18_traced.pt") # 保存模型[^1] # 使用脚本化方法生成 TorchScript 模型 scripted_model = torch.jit.script(model) scripted_model.save("resnet18_scripted.pt") # 保存模型[^2] ``` #### 2. 模型优化 为了提高移动端性能,可以对模型进行量化处理。量化会将浮点数权重转换为较低精度的整数,从而减少内存占用并加速推理。 ```python # 示例:后训练静态量化 quantized_model = torch.quantization.quantize_dynamic( model, # 模型 {torch.nn.Linear}, # 量化的目标层类型 dtype=torch.qint8 # 目标数据类型 ) quantized_model.eval() torch.jit.save(torch.jit.script(quantized_model), "resnet18_quantized.pt") # 保存量化模型[^3] ``` #### 3. 集成到移动端应用 PyTorch 提供了官方库 `libtorch` 和工具链来支持 Android 和 iOS 平台上的模型部署。 - **Android** 在 Android 上,可以通过使用 C++ 或 Java API 来加载和运行模型。以下是一个简单的 Java 代码示例: ```java import org.pytorch.*; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); // 加载模型 Module module = Module.load(assetFilePath(this, "resnet18_scripted.pt")); // 准备输入张量 Tensor inputTensor = Tensor.fromBlob(new float[1 * 3 * 224 * 224], new long[]{1, 3, 224, 224}); // 执行推理 Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); } private static String assetFilePath(Context context, String assetName) { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(assetName)) { file.createNewFile(); FileOutputStream os = new FileOutputStream(file); byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.close(); } catch (IOException e) { throw new RuntimeException(e); } return file.getAbsolutePath(); } } ``` - **iOS** 在 iOS 上,可以使用 Swift 或 Objective-C 来加载和运行模型。以下是一个 Swift 代码示例: ```swift import PyTorchMobile let module = try! MobileModule(bundle: .main, name: "resnet18_scripted") // 创建随机输入张量 let inputTensor = Tensor<Float>(randomIn: (-1...1), shape: [1, 3, 224, 224]) // 执行推理 let outputTensor = try! module.run(withInputs: [inputTensor]) ``` #### 4. 性能调优 在移动端部署时,还需要考虑设备的硬件限制。以下是一些优化建议: - 使用 GPU 加速推理(如果设备支持)。 - 尽可能减少模型大小,例如通过剪枝或知识蒸馏。 - 确保模型输入符合预期尺寸以避免额外的计算开销[^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值