将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

本文介绍如何将TensorFlow训练的MNIST手写数字识别模型移植到Android平台,包括模型训练、保存及在Android上的加载和使用。
部署运行你感兴趣的模型镜像

将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

【转载】https://blog.youkuaiyun.com/guyuealian/article/details/79672257

项目Github下载地址:https://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo麻烦给个“star”哈

       本博客将以最简单的方式,利用TensorFlow实现了MNIST手写数字识别,并将Python TensoFlow训练好的模型移植到Android手机上运行。网上也有很多移植教程,大部分是在Ubuntu(Linux)系统,一般先利用Bazel工具把TensoFlow编译成.so库文件和jar包,再进行Android配置,实现模型移植。不会使用Bazel也没关系,实质上TensoFlow已经为开发者提供了最新的.so库文件和对应的jar包了(如libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar),我们只需要下载文件,并在本地Android Studio导入jar包和.so库文件,即可以在Android加载TensoFlow的模型了。 

      当然了,本博客的项目代码都上传到Githubhttps://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo

      先说一下,本人的开发环境:

  • Windows 7
  • Python3.5
  • TensoFlow 1.6.0(2018年3月23日—当前最新版)
  • Android Studio 3.0.1(2018年3月23日—当前最新版)

一、利用Python训练模型

   以MNIST手写数字识别为例,这里首先使用Python版的TensorFlow实现单隐含层的SoftMax Regression分类器,并将训练好的模型的网络拓扑结构和参数保存为pb文件。首先,需要定义模型的输入层和输出层节点的名字(通过形参 'name'指定,名字可以随意,后面加载模型时,都是通过该name来传递数据的):


 
  1. x = tf.placeholder(tf.float32,[ None, 784],name= 'x_input') #输入节点:x_input
  2. .
  3. .
  4. .
  5. pre_num=tf.argmax(y, 1,output_type= 'int32',name= "output") #输出节点:output

PS:说一下鄙人遇到坑:起初,我参照网上相关教程训练了一个模型,在Windows下测试没错,但把模型移植到Android后就出错了,但用别人的模型又正常运行;后来折腾了半天才发现,是类型转换出错啦!!!!
TensorFlow默认类型是float32,但我们希望返回的是一个int型,因此需要指定output_type='int32';但注意了,在Windows下测试使用int64和float64都是可以的,但在Android平台上只能使用int32和float32,并且对应Java的int和float类型。

 将训练好的模型保存为.pb文件,这就需要用到tf.graph_util.convert_variables_to_constants函数了。


 
  1. # 保存训练好的模型
  2. #形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
  3. output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=[ 'output'])
  4. with tf.gfile.FastGFile( 'model/mnist.pb', mode= 'wb') as f: #’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
  5. f.write(output_graph_def.SerializeToString())

   关于tensorflow保存模型和加载模型的方法,请参考本人另一篇博客:https://blog.youkuaiyun.com/guyuealian/article/details/79693741

   这里给出Python训练模型完整的代码如下:


 
  1. #coding=utf-8
  2. # 单隐层SoftMax Regression分类器:训练和保存模型模块
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. import tensorflow as tf
  5. from tensorflow.python.framework import graph_util
  6. print( 'tensortflow:{0}'.format(tf.__version__))
  7. mnist = input_data.read_data_sets( "Mnist_data/", one_hot= True)
  8. #create model
  9. with tf.name_scope( 'input'):
  10. x = tf.placeholder(tf.float32,[ None, 784],name= 'x_input') #输入节点名:x_input
  11. y_ = tf.placeholder(tf.float32,[ None, 10],name= 'y_input')
  12. with tf.name_scope( 'layer'):
  13. with tf.name_scope( 'W'):
  14. #tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
  15. W = tf.Variable(tf.zeros([ 784, 10]),name= 'Weights')
  16. with tf.name_scope( 'b'):
  17. b = tf.Variable(tf.zeros([ 10]),name= 'biases')
  18. with tf.name_scope( 'W_p_b'):
  19. Wx_plus_b = tf.add(tf.matmul(x, W), b, name= 'Wx_plus_b')
  20. y = tf.nn.softmax(Wx_plus_b, name= 'final_result')
  21. # 定义损失函数和优化方法
  22. with tf.name_scope( 'loss'):
  23. loss = -tf.reduce_sum(y_ * tf.log(y))
  24. with tf.name_scope( 'train_step'):
  25. train_step = tf.train.GradientDescentOptimizer( 0.01).minimize(loss)
  26. print(train_step)
  27. # 初始化
  28. sess = tf.InteractiveSession()
  29. init = tf.global_variables_initializer()
  30. sess.run(init)
  31. # 训练
  32. for step in range( 100):
  33. batch_xs,batch_ys =mnist.train.next_batch( 100)
  34. train_step.run({x:batch_xs,y_:batch_ys})
  35. # variables = tf.all_variables()
  36. # print(len(variables))
  37. # print(sess.run(b))
  38. # 测试模型准确率
  39. pre_num=tf.argmax(y, 1,output_type= 'int32',name= "output") #输出节点名:output
  40. correct_prediction = tf.equal(pre_num,tf.argmax(y_, 1,output_type= 'int32'))
  41. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  42. a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})
  43. print( '测试正确率:{0}'.format(a))
  44. # 保存训练好的模型
  45. #形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),
  46. output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=[ 'output'])
  47. with tf.gfile.FastGFile( 'model/mnist.pb', mode= 'wb') as f: #’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
  48. f.write(output_graph_def.SerializeToString())
  49. sess.close()

上面的代码已经将训练模型保存在model/mnist.pb,当然我们可以先在Python中使用该模型进行简单的预测,测试方法如下:


 
  1. import tensorflow as tf
  2. import numpy as np
  3. from PIL import Image
  4. import matplotlib.pyplot as plt
  5. #模型路径
  6. model_path = 'model/mnist.pb'
  7. #测试图片
  8. testImage = Image.open( "data/test_image.jpg");
  9. with tf.Graph().as_default():
  10. output_graph_def = tf.GraphDef()
  11. with open(model_path, "rb") as f:
  12. output_graph_def.ParseFromString(f.read())
  13. tf.import_graph_def(output_graph_def, name= "")
  14. with tf.Session() as sess:
  15. tf.global_variables_initializer().run()
  16. # x_test = x_test.reshape(1, 28 * 28)
  17. input_x = sess.graph.get_tensor_by_name( "input/x_input:0")
  18. output = sess.graph.get_tensor_by_name( "output:0")
  19. #对图片进行测试
  20. testImage=testImage.convert( 'L')
  21. testImage = testImage.resize(( 28, 28))
  22. test_input=np.array(testImage)
  23. test_input = test_input.reshape( 1, 28 * 28)
  24. pre_num = sess.run(output, feed_dict={input_x: test_input}) #利用训练好的模型预测结果
  25. print( '模型预测结果为:',pre_num)
  26. #显示测试的图片
  27. # testImage = test_x.reshape(28, 28)
  28. fig = plt.figure(), plt.imshow(testImage,cmap= 'binary') # 显示图片
  29. plt.title( "prediction result:"+str(pre_num))
  30. plt.show()

二、移植到Android

    相信大家看到很多大神的博客,都是要自己编译TensoFlow的so库和jar包,说实在的,这个过程真TM麻烦,反正我弄了半天都没成功过,然后放弃了……。本博客的移植方法不需要安装Bazel,也不需要构建TensoFlow的so库和jar包,因为Google在TensoFlow github中给我们提供了,为什么不用了!!!

1、下载TensoFlow的jar包和so库

    TensoFlow在Github已经存放了很多开发文件:https://github.com/PanJinquan/tensorflow

   我们需要做的是,下载Android: native libs ,打包下载全部文件,其中有我们需要的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar,有了这两个文件,剩下的就是在Android Studio配置的问题了

2、Android Studio配置

(1)新建一个Android项目

(2)把训练好的pb文件(mnist.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,右键main->new->Directory,输入assets。

(3)将下载的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar如下结构放在libs文件夹下

 

(4)app\build.gradle配置

    在defaultConfig中添加


 
  1. multiDexEnabled true
  2. ndk {
  3. abiFilters "armeabi-v7a"
  4. }

    增加sourceSets


 
  1. sourceSets {
  2. main {
  3. jniLibs.srcDirs = ['libs']
  4. }
  5. }

    在dependencies中增加TensoFlow编译的jar文件libandroid_tensorflow_inference_java.jar:

    compile files('libs/libandroid_tensorflow_inference_java.jar')

 

   OK了,build.gradle配置完成了,剩下的就是java编程的问题了。

3、模型调用

  在需要调用TensoFlow的地方,加载so库“System.loadLibrary("tensorflow_inference");并”import org.tensorflow.contrib.android.TensorFlowInferenceInterface;就可以使用了

     注意,旧版的TensoFlow,是如下方式进行,该方法可参考大神的博客:https://www.jianshu.com/p/1168384edc1e


 
  1. TensorFlowInferenceInterface.fillNodeFloat(); //送入输入数据
  2. TensorFlowInferenceInterface.runInference(); //进行模型的推理
  3. TensorFlowInferenceInterface.readNodeFloat(); //获取输出数据

     但在最新的libandroid_tensorflow_inference_java.jar中,已经没有这些方法了,换为


 
  1. TensorFlowInferenceInterface.feed()
  2. TensorFlowInferenceInterface.run()
  3. TensorFlowInferenceInterface.fetch()

     下面是以MNIST手写数字识别为例,其实现方法如下:


 
  1. package com.example.jinquan.pan.mnist_ensorflow_androiddemo;
  2. import android.content.res.AssetManager;
  3. import android.graphics.Bitmap;
  4. import android.graphics.Color;
  5. import android.graphics.Matrix;
  6. import android.util.Log;
  7. import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
  8. public class PredictionTF {
  9. private static final String TAG = "PredictionTF";
  10. //设置模型输入/输出节点的数据维度
  11. private static final int IN_COL = 1;
  12. private static final int IN_ROW = 28* 28;
  13. private static final int OUT_COL = 1;
  14. private static final int OUT_ROW = 1;
  15. //模型中输入变量的名称
  16. private static final String inputName = "input/x_input";
  17. //模型中输出变量的名称
  18. private static final String outputName = "output";
  19. TensorFlowInferenceInterface inferenceInterface;
  20. static {
  21. //加载libtensorflow_inference.so库文件
  22. System.loadLibrary( "tensorflow_inference");
  23. Log.e(TAG, "libtensorflow_inference.so库加载成功");
  24. }
  25. PredictionTF(AssetManager assetManager, String modePath) {
  26. //初始化TensorFlowInferenceInterface对象
  27. inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
  28. Log.e(TAG, "TensoFlow模型文件加载成功");
  29. }
  30. /**
  31. * 利用训练好的TensoFlow模型预测结果
  32. * @param bitmap 输入被测试的bitmap图
  33. * @return 返回预测结果,int数组
  34. */
  35. public int[] getPredict(Bitmap bitmap) {
  36. float[] inputdata = bitmapToFloatArray(bitmap, 28, 28); //需要将图片缩放带28*28
  37. //将数据feed给tensorflow的输入节点
  38. inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
  39. //运行tensorflow
  40. String[] outputNames = new String[] {outputName};
  41. inferenceInterface.run(outputNames);
  42. ///获取输出节点的输出信息
  43. int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据
  44. inferenceInterface.fetch(outputName, outputs);
  45. return outputs;
  46. }
  47. /**
  48. * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
  49. * @param bitmap 输入被测试的bitmap图片
  50. * @param rx 将图片缩放到指定的大小(列)->28
  51. * @param ry 将图片缩放到指定的大小(行)->28
  52. * @return 返回归一化后的一维float数组 ->28*28
  53. */
  54. public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){
  55. int height = bitmap.getHeight();
  56. int width = bitmap.getWidth();
  57. // 计算缩放比例
  58. float scaleWidth = (( float) rx) / width;
  59. float scaleHeight = (( float) ry) / height;
  60. Matrix matrix = new Matrix();
  61. matrix.postScale(scaleWidth, scaleHeight);
  62. bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
  63. Log.i(TAG, "bitmap width:"+bitmap.getWidth()+ ",height:"+bitmap.getHeight());
  64. Log.i(TAG, "bitmap.getConfig():"+bitmap.getConfig());
  65. height = bitmap.getHeight();
  66. width = bitmap.getWidth();
  67. float[] result = new float[height*width];
  68. int k = 0;
  69. //行优先
  70. for( int j = 0;j < height;j++){
  71. for ( int i = 0;i < width;i++){
  72. int argb = bitmap.getPixel(i,j);
  73. int r = Color.red(argb);
  74. int g = Color.green(argb);
  75. int b = Color.blue(argb);
  76. int a = Color.alpha(argb);
  77. //由于是灰度图,所以r,g,b分量是相等的。
  78. assert(r==g && g==b);
  79. // Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);
  80. result[k++] = r / 255.0f;
  81. }
  82. }
  83. return result;
  84. }
  85. }
  • 简单说明一下:项目新建了一个PredictionTF类,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径;
  •  从资源文件中获取BitMap图片,并传入 getPredict(Bitmap bitmap)方法,该方法首先将BitMap图像缩放到28*28的大小,由于原图是灰度图,我们需要获取灰度图的像素值,并将28*28的像素转存为行向量的一个float数组,并且每个像素点都归一化到0~1之间,这个就是bitmapToFloatArray(Bitmap bitmap, int rx, int ry)方法的作用;
  •  然后将数据feed给tensorflow的输入节点,并运行(run)tensorflow,最后获取(fetch)输出节点的输出信息。

   MainActivity很简单,一个单击事件获取预测结果:


 
  1. package com.example.jinquan.pan.mnist_ensorflow_androiddemo;
  2. import android.graphics.Bitmap;
  3. import android.graphics.BitmapFactory;
  4. import android.support.v7.app.AppCompatActivity;
  5. import android.os.Bundle;
  6. import android.util.Log;
  7. import android.view.View;
  8. import android.widget.ImageView;
  9. import android.widget.TextView;
  10. public class MainActivity extends AppCompatActivity {
  11. // Used to load the 'native-lib' library on application startup.
  12. static {
  13. System.loadLibrary( "native-lib"); //可以去掉
  14. }
  15. private static final String TAG = "MainActivity";
  16. private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; //模型存放路径
  17. TextView txt;
  18. TextView tv;
  19. ImageView imageView;
  20. Bitmap bitmap;
  21. PredictionTF preTF;
  22. @Override
  23. protected void onCreate(Bundle savedInstanceState) {
  24. super.onCreate(savedInstanceState);
  25. setContentView(R.layout.activity_main);
  26. // Example of a call to a native method
  27. tv = (TextView) findViewById(R.id.sample_text);
  28. txt=(TextView)findViewById(R.id.txt_id);
  29. imageView =(ImageView)findViewById(R.id.imageView1);
  30. bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test_image);
  31. imageView.setImageBitmap(bitmap);
  32. preTF = new PredictionTF(getAssets(),MODEL_FILE); //输入模型存放路径,并加载TensoFlow模型
  33. }
  34. public void click01(View v){
  35. String res= "预测结果为:";
  36. int[] result= preTF.getPredict(bitmap);
  37. for ( int i= 0;i<result.length;i++){
  38. Log.i(TAG, res+result[i] );
  39. res=res+String.valueOf(result[i])+ " ";
  40. }
  41. txt.setText(res);
  42. tv.setText(stringFromJNI());
  43. }
  44. /**
  45. * A native method that is implemented by the 'native-lib' native library,
  46. * which is packaged with this application.
  47. */
  48. public native String stringFromJNI(); //可以去掉
  49. }

   activity_main布局文件:


 
  1. <?xml version="1.0" encoding="utf-8"?>
  2. <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
  3. android:layout_width= "match_parent"
  4. android:layout_height= "match_parent"
  5. android:orientation= "vertical"
  6. android:paddingBottom= "16dp"
  7. android:paddingLeft= "16dp"
  8. android:paddingRight= "16dp"
  9. android:paddingTop= "16dp">
  10. <TextView
  11. android:id= "@+id/sample_text"
  12. android:layout_width= "wrap_content"
  13. android:layout_height= "wrap_content"
  14. android:text= "https://blog.youkuaiyun.com/guyuealian"
  15. android:layout_gravity= "center"/>
  16. <Button
  17. android:onClick= "click01"
  18. android:layout_width= "match_parent"
  19. android:layout_height= "wrap_content"
  20. android:text= "click" />
  21. <TextView
  22. android:id= "@+id/txt_id"
  23. android:layout_width= "match_parent"
  24. android:layout_height= "wrap_content"
  25. android:gravity= "center"
  26. android:text= "结果为:"/>
  27. <ImageView
  28. android:id= "@+id/imageView1"
  29. android:layout_width= "wrap_content"
  30. android:layout_height= "wrap_content"
  31. android:layout_gravity= "center"/>
  32. </LinearLayout>

最后一步,就是run,run,run,效果如下, 

博客的项目代码都上传到Github:下载地址:https://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo

 

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值