前言
当前,深度学习有很多框架:tensorflow、pytorch、caffe、keras等。很多场景下,需要训练好的模型在移动端运行。移动端的框架又有很多TensorFlow Lite、Core ML、NCNN、MNN等等。
其中 tensorflow 所对应的移动端移植框架 TensorFlow Lite。在自己这个系列记录使用 调用tfile进行神经网络预测的android的实现。
整个记录为:
Android(1) —— Android studio 开发环境搭建
Android(2) —— Android Studio找不到连接的手机
Android(3) —— 环境配置、手机端界面设计
Android(4) —— 图像分类的*.tfile的使用 Classify.java
Android(5) —— 安卓机通过相机或相册获取图片PhotoUtil.java
Android(6) —— 主函数的详解 MainActivity.java
1 代码讲解
1.1 代码概述
该脚定义了使用java读取 神经网络的移动端的tfile模型,并通过神经网络预测图片,得到最终的预测结果。是通过定义了
public class ClassifyLib
来实现的。
(神经网络实现任务以10分类为例)
【一般流程】
- 加载网络模型
- 读取数据(视情况而定进行图片预测里),并进行神经网络预测
- 解析预测结果
先定义了类内的所需变量
public class ClassifyLib { //输入图片的大小,根据自己的实际情况修改 private int[] ddims = {1, 3, 224, 224}; //用于加载预测图片的标签,最终实际预测时,忽略这一项即可 private List<String> resultLabel = new ArrayList<>(); //新建一个全局 解译器对象,用来加载模型、运行模型、释放模型 private Interpreter tflite = null; //修改为自己的模型名称 private String modelname = "1999";
1.2 加载分类模型
从资源文件中读取模型文件,Google官方提供的读取tfile文件的方法,与Interpreter 配合使用,常规操作
private MappedByteBuffer loadModelFile( Context context) throws IOException { //获取通过openFd()的方法获取asset目录下指定文件的 AssetFileDescriptor对象 AssetFileDescriptor fileDescriptor = context.getAssets().openFd(modelname + ".tflite"); //返回可用于读取文件中的数据的FileDescriptor对象 FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel();//返回与此文件输入流关联的通道 long startOffset = fileDescriptor.getStartOffset();//返回asset中项的数据开始处的字节偏移量 long declaredLength = fileDescriptor.getDeclaredLength();//返回构造AssetFileDescriptor时声明的实际字节数 //map方法来把文件影射为内存映像文件 把文件的从position开始的size大小的区域映射为内存映像文件, // mode指出了 可访问该内存映像文件的方式:READ_ONLY,READ_WRITE,PRIVATE return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);// }
- AssetFileDescriptor 资源文件描述器,用于读取文件中的数据
- JAVA中针对文件的读写操作设置了一系列的流,其中主要有FileInputStream、FileOutputStream、FileReader、FileWriter四种最为常用的流。
FileInputStream 流被称为文件字节输入流,意思指对文件数据以字节的形式进行读取操作如读取图片视频等- FileChannel 配合着ByteBuffer,将读写的数据缓存到内存中,然后以批量/缓存的方式read/write,
省去了非批量操作时的重复中间操作,操纵大文件时可以显著提高效率
详细的说:对于文件的复制,平时我们都是使用输入输出流进行操作,利用源文件创建出一个输入流,然后利用目标文件创建出一个输出流,最后将输入流的数据读取写入到输出流中。这样也是可以进行操作的。但是利用fileChannel是很有用的一个方式。它能直接连接输入输出流的文件通道,将数据直接写入到目标文件中去。而且效率更高。FileChannel 其实就是为了让文件快速复制,而获取的一个快速通道。- *.map 是把文件映射为内存映像文件
执行模型读取
public boolean load_model(Context context) { try { Interpreter.Options options = new Interpreter.Options(); options.setNumThreads(4); // 4线程运行 tflite = new Interpreter(loadModelFile(context), options); // 进行模型读取 return true; } catch (IOException e) { //e.printStackTrace(); return false; } }
1.3 读取图片并进行预处理
在Java中当我们要对数据进行更底层的操作时,一般是操作数据的字节(byte)形式,一个字节共8个二进制位,取值范围 [-128,127],或者[0, 255]。
【该函数的操作】:
- 分配ByteBuffer,size为神经网络输入图片的大小
- 将输入的图片 bitmap缩放,size为神经网络输入图片的大小
- 对缩放后的图片进行解析和预处理,然后放入ByteBuffer中。(其中预处理的操作过程如果封装到了tfile中,这里就不需要操作了)
private static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) { //step1:每个像素点的三个分量都是一个int类型 java中int占4个字节 因此我们预先分配ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4 ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4); imgData.order(ByteOrder.nativeOrder()); // imgData的字节序按照当前机器使用的字节序 // get image pixel,size为图片的大小,这里是224x224 int[] pixels = new int[ddims[2] * ddims[3]]; //step2:将原图片按照ddims[2], ddims[3] 进行缩放 filter决定是否平滑 Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false); //将bm中的每个像素颜色转为int值存入pixels 每隔bm.getWidth()个像素换一行 每个值都是一个十进制 有时候还是负数 bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, ddims[2], ddims[3]); // val -15001066 tmp1 =27 tmp2=26 tmp3 = 22 int pixel = 0; for (int i = 0; i < ddims[2]; ++i) { for (int j = 0; j < ddims[3]; ++j) { final int val = pixels[pixel++]; // tmp1、tmp2、tmp3 是为了调试阶段,监测数值而定义的 int tmp1 = ((val >> 16) & 0xFF); int tmp2 = ((val >> 8) & 0xFF); int tmp3 = ((val ) & 0xFF); //预处理非常重要 imgData.putFloat(((val >> 16) & 0xFF) ); imgData.putFloat(((val >> 8) & 0xFF) ); imgData.putFloat((val & 0xFF) ); // imgData.putFloat(((((val >> 16) & 0xFF) - 128f) / 128f)); // imgData.putFloat(((((val >> 8) & 0xFF) - 128f) / 128f)); // imgData.putFloat((((val & 0xFF) - 128f) / 128f)) } } if (bm.isRecycled()) { bm.recycle(); } return imgData; }
1.4 预测图片分类 并后处理
【该函数的操作】:
- 获取输入图片的 ByteBuffer(并转换为 Object[] 的格式,用于多输入输出的预测API)
- 创建数组,来存放预测的结果(数组转化为Object[] 的格式,用于多输入输出的预测API)
- tfile.run() 或者 runForMultipleInputsOutputs()
- 解析预测结果
public String predict_image(Bitmap bmp) { boolean A = true; //自己测试使用的一个标志位 ByteBuffer inputData = getScaledMatrix(bmp, ddims); String show_text = "" ; try { if (A) { //单输入输出、tflite.run的使用 //原始模型输出的结果为10类 1行10列 属于10类的概率值 因此new一个float类型的数组用来存放run之后的结果 final float[][] labelProbArray = new float[1][10]; long start = System.currentTimeMillis(); tflite.run(inputData, labelProbArray);//运行模型 long end = System.currentTimeMillis(); long time = end - start; float[] results = new float[labelProbArray[0].length]; //为了方便计算 将第一行的10个概率值拷贝到一个一维数组results System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length); int r = get_max_result(results);//输出数据后处理 show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms"; } else { //多输入输出、tflite.runForMultipleInputsOutputs的使用 //原始模型输出的结果为10类 1行10列 属于10类的概率值 因此new一个float类型的数组用来存放run之后的结果 //float[][] labelProbArray = new float[1][10]; final float[][] labelProbArray = new float[1][10]; //为输出数据建立与模型输出结果对应的变量 Object[] inputArray = {inputData}; Map<Integer, Object> outputMap = new HashMap(){ {put(0,labelProbArray);} }; long start = System.currentTimeMillis(); tflite.runForMultipleInputsOutputs(inputArray,outputMap); long end = System.currentTimeMillis(); long time = end - start; float[] results = new float[((float[][])outputMap.get(0))[0].length]; //为了方便计算 将第一行的10个概率值拷贝到一个一维数组results System.arraycopy(((float[][])outputMap.get(0))[0], 0, results, 0, ((float[][])outputMap.get(0))[0].le ngth); int r = get_max_result(results);//输出数据后处理 show_text = "result:" + r + "\nname:" + resultLabel.get(r) + "\nprobability:" + results[r] + "\ntime:" + time + "ms"; } } catch (Exception e) { e.printStackTrace(); } return show_text; }
解析预测结果:对于分类神经网络的输出,是预测各类的概率。只需要遍历,找出输出向量中最大值的索引即可
private int get_max_result(float[] result) { float probability = result[0]; int r = 0; //通过一个循环找到概率值最大的类别索引 for (int i = 0; i < result.length; i++) { if (probability < result[i]) { probability = result[i]; r = i; } } return r; }
1.5 模型释放与关闭
public void close(){ tflite.close(); }