Android(4) —— 图像分类的*.tfile的使用 Classify.java

本文介绍如何在Android上使用TensorFlow Lite进行神经网络预测,包括模型加载、图片预处理、预测及结果解析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

当前,深度学习有很多框架:tensorflow、pytorch、caffe、keras等。很多场景下,需要训练好的模型在移动端运行。移动端的框架又有很多TensorFlow Lite、Core ML、NCNN、MNN等等。
其中 tensorflow 所对应的移动端移植框架 TensorFlow Lite。在自己这个系列记录使用 调用tfile进行神经网络预测的android的实现。

整个记录为:
Android(1) —— Android studio 开发环境搭建
Android(2) —— Android Studio找不到连接的手机
Android(3) —— 环境配置、手机端界面设计
Android(4) —— 图像分类的*.tfile的使用 Classify.java
Android(5) —— 安卓机通过相机或相册获取图片PhotoUtil.java
Android(6) —— 主函数的详解 MainActivity.java

1 代码讲解

1.1 代码概述

该脚定义了使用java读取 神经网络的移动端的tfile模型,并通过神经网络预测图片,得到最终的预测结果。是通过定义了public class ClassifyLib 来实现的。
(神经网络实现任务以10分类为例)

一般流程

  1. 加载网络模型
  2. 读取数据(视情况而定进行图片预测里),并进行神经网络预测
  3. 解析预测结果

先定义了类内的所需变量

public class ClassifyLib {
   //输入图片的大小,根据自己的实际情况修改
   private int[] ddims = {1, 3, 224, 224};
   //用于加载预测图片的标签,最终实际预测时,忽略这一项即可
   private List<String> resultLabel = new ArrayList<>();
   //新建一个全局 解译器对象,用来加载模型、运行模型、释放模型
   private Interpreter tflite = null;
   //修改为自己的模型名称
   private String modelname = "1999";

1.2 加载分类模型


从资源文件中读取模型文件,Google官方提供的读取tfile文件的方法,与Interpreter 配合使用,常规操作

private MappedByteBuffer loadModelFile( Context context) throws IOException {
   //获取通过openFd()的方法获取asset目录下指定文件的 AssetFileDescriptor对象
   AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelname + ".tflite");

   //返回可用于读取文件中的数据的FileDescriptor对象
   FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
   FileChannel fileChannel = inputStream.getChannel();//返回与此文件输入流关联的通道
   long startOffset = fileDescriptor.getStartOffset();//返回asset中项的数据开始处的字节偏移量
   long declaredLength = fileDescriptor.getDeclaredLength();//返回构造AssetFileDescriptor时声明的实际字节数

   //map方法来把文件影射为内存映像文件 把文件的从position开始的size大小的区域映射为内存映像文件,
   // mode指出了 可访问该内存映像文件的方式:READ_ONLY,READ_WRITE,PRIVATE
   return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);//
}

  1. AssetFileDescriptor 资源文件描述器,用于读取文件中的数据
  2. JAVA中针对文件的读写操作设置了一系列的流,其中主要有FileInputStream、FileOutputStream、FileReader、FileWriter四种最为常用的流。
    FileInputStream 流被称为文件字节输入流,意思指对文件数据以字节的形式进行读取操作如读取图片视频等
  3. FileChannel 配合着ByteBuffer,将读写的数据缓存到内存中,然后以批量/缓存的方式read/write,
    省去了非批量操作时的重复中间操作,操纵大文件时可以显著提高效率
    详细的说:对于文件的复制,平时我们都是使用输入输出流进行操作,利用源文件创建出一个输入流,然后利用目标文件创建出一个输出流,最后将输入流的数据读取写入到输出流中。这样也是可以进行操作的。但是利用fileChannel是很有用的一个方式。它能直接连接输入输出流的文件通道,将数据直接写入到目标文件中去。而且效率更高。FileChannel 其实就是为了让文件快速复制,而获取的一个快速通道。
  4. *.map 是把文件映射为内存映像文件

执行模型读取

public boolean load_model(Context context) {
   try {
       Interpreter.Options options = new Interpreter.Options();
       options.setNumThreads(4); // 4线程运行
       tflite = new Interpreter(loadModelFile(context), options); // 进行模型读取
       return true;

   } catch (IOException e) {
       //e.printStackTrace();
       return false;
   }
}

1.3 读取图片并进行预处理

在Java中当我们要对数据进行更底层的操作时,一般是操作数据的字节(byte)形式,一个字节共8个二进制位,取值范围 [-128,127],或者[0, 255]。

【该函数的操作】

  1. 分配ByteBuffer,size为神经网络输入图片的大小
  2. 将输入的图片 bitmap缩放,size为神经网络输入图片的大小
  3. 对缩放后的图片进行解析和预处理,然后放入ByteBuffer中。(其中预处理的操作过程如果封装到了tfile中,这里就不需要操作了)
   private static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {

       //step1:每个像素点的三个分量都是一个int类型 java中int占4个字节 因此我们预先分配ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4
       ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
       imgData.order(ByteOrder.nativeOrder());  // imgData的字节序按照当前机器使用的字节序

       // get image pixel,size为图片的大小,这里是224x224
       int[] pixels = new int[ddims[2] * ddims[3]];

       //step2:将原图片按照ddims[2], ddims[3] 进行缩放 filter决定是否平滑
       Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);

       //将bm中的每个像素颜色转为int值存入pixels 每隔bm.getWidth()个像素换一行 每个值都是一个十进制 有时候还是负数
       bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]);

       // val -15001066 tmp1 =27 tmp2=26 tmp3 = 22
       int pixel = 0;
       for (int i = 0; i < ddims[2]; ++i) {
           for (int j = 0; j < ddims[3]; ++j) {
               final int val = pixels[pixel++];
               // tmp1、tmp2、tmp3 是为了调试阶段,监测数值而定义的
               int tmp1 = ((val >> 16) & 0xFF);
               int tmp2 = ((val >> 8) & 0xFF);
               int tmp3 = ((val ) & 0xFF);
               //预处理非常重要
               imgData.putFloat(((val >> 16) & 0xFF) );
               imgData.putFloat(((val >> 8) & 0xFF) );
               imgData.putFloat((val & 0xFF) );
//                imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f));
//                imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f));
//                imgData.putFloat((((val & 0xFF) - 128f) / 128f))
           }
       }

       if (bm.isRecycled()) {
           bm.recycle();
       }
       return imgData;
   }

1.4 预测图片分类 并后处理

【该函数的操作】

  1. 获取输入图片的 ByteBuffer(并转换为 Object[] 的格式,用于多输入输出的预测API)
  2. 创建数组,来存放预测的结果(数组转化为Object[] 的格式,用于多输入输出的预测API)
  3. tfile.run() 或者 runForMultipleInputsOutputs()
  4. 解析预测结果
public String predict_image(Bitmap bmp) {
   boolean A = true; //自己测试使用的一个标志位
   
   ByteBuffer inputData = getScaledMatrix(bmp, ddims);
   String show_text = "" ;
   try {
       if (A) {
           //单输入输出、tflite.run的使用
           
           //原始模型输出的结果为10类 1行10列 属于10类的概率值 因此new一个float类型的数组用来存放run之后的结果
           final float[][] labelProbArray = new float[1][10];
   
           long start = System.currentTimeMillis();
           tflite.run(inputData, labelProbArray);//运行模型
           long end = System.currentTimeMillis();
           long time = end - start;
   
           float[] results = new float[labelProbArray[0].length];
           //为了方便计算 将第一行的10个概率值拷贝到一个一维数组results
           System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
           int r = get_max_result(results);//输出数据后处理
           show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
       }
       else {
           //多输入输出、tflite.runForMultipleInputsOutputs的使用
           
           //原始模型输出的结果为10类 1行10列 属于10类的概率值 因此new一个float类型的数组用来存放run之后的结果
           //float[][] labelProbArray = new float[1][10];
           final float[][] labelProbArray = new float[1][10]; //为输出数据建立与模型输出结果对应的变量
           Object[] inputArray = {inputData};
           Map<Integer, Object> outputMap = new HashMap(){
               {put(0,labelProbArray);}
           };
           long start = System.currentTimeMillis();
           tflite.runForMultipleInputsOutputs(inputArray,outputMap);
           long end = System.currentTimeMillis();
           long time = end - start;
   
           float[] results = new float[((float[][])outputMap.get(0))[0].length];
           //为了方便计算 将第一行的10个概率值拷贝到一个一维数组results
           System.arraycopy(((float[][])outputMap.get(0))[0], 0, results, 0, ((float[][])outputMap.get(0))[0].le    ngth);
           int r = get_max_result(results);//输出数据后处理
           show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
       }

   } catch (Exception e) {
       e.printStackTrace();
   }
   return show_text;
}

解析预测结果:对于分类神经网络的输出,是预测各类的概率。只需要遍历,找出输出向量中最大值的索引即可

private int get_max_result(float[] result) {
   float probability = result[0];
   int r = 0;
   //通过一个循环找到概率值最大的类别索引
   for (int i = 0; i < result.length; i++) {
       if (probability < result[i]) {
           probability = result[i];
           r = i;
       }
   }
   return r;
}

1.5 模型释放与关闭

public void close(){

   tflite.close();

}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值