前述
最近在做一个视觉方面的Demo。坑当然是多到不行,想到这都是了解生态的一个过程,也就不那么烦躁。我们的模型训练部分往往是用Python写Keras或者直接上TensorFlow,然后得到model。但部署这件事还没听说直接用Python就能解决,大多需要别的工具。
第一种方式是通过网络,以服务器、客户端的形式实现。这时候可以写个简单的Flask接口就可以实现建议的模型部署,稍复杂、专业一些就可以用到TF Serving之类的专门的部署工具。很容易理解,这种方式使用模型的服务必须联网,由于是一些视觉方面的应用,对网络的要求可能还比较高。
第二种方式是本地化部署,将模型打包在App中直接在本地调用。App一般情况下都不是Python开发,更可能是JS、Swift、C++、Java等其他语言(TF支持JS、Swift、C/C++、Java、Go等等)。这种方式的最大缺陷就是受到设备计算资源的限制,但在我的Demo中勉强能够使用。最终也是选择了这种方式。
部署原理
详细来说,这篇是使用Java调用Python训练的模型。首先第一个坑,Java目前好像只支持TensorFlow1,所以Python训练也不能使用TensorFlow2。简单来说,这个部署过程就是将h5、checkpoint格式的model转成pb格式的model。在Java中只能读取pb格式的model。
格式转换工具
转换格式一般采用Python脚本,推荐使用pyenv保持同Java TF版本一致的Python版本及TF版本(曾因版本不同出现过莫名的问题)。关于TF的安装
如果使用Keras,推荐使用keras_to_tensorflow。
# ReadMe中有使用方法
# python keras_to_tensorflow.py
# --input_model="path/to/keras/model.h5"
# --input_model_json="path/to/keras/model.json"
# --output_model="path/to/save/model.pb"
# h5文件通过save_weights保存
model.save_weights('model.h5')
# json文件通过to_json得到
with open("model.json", "w") as json_file:
json_file.write(model.to_json())如果使用TensorFlow,可以使用tf-ckpt-2-pb。
# ReadMe中的Usage
# python convert.py
# --checkpoint "path/to/tf/ckpt_weight"
# --model "path/to/tf/ckpt_weight/model.ckpt.meta"
# --out-path "path/to/save/out.pb"
# checkpoint是ckpt文件夹
tf.train.Saver().save(session, path)
# mdoel是ckpt文件夹中的meta文件找出model的input和output
Java调用TF model需要知道Input layer name和Output layer name,这时候可能需要使用tensorboard工具,自己去看网络结构。
import tensorflow as tf
from tensorflow.summary import FileWriter
sess = tf.Session()
tf.train.import_meta_graph("path/to/tf/ckpt_weight/model.ckpt.meta")
FileWriter("__tb", sess.graph)
# after run python script,
# run cmd: tensorboard --logdir __tb只出不入的大概就是Input layer,只入不出的很可能就是Output layer。唯一需要注意的是将名称写全,比如有仅一层的名字input,也有几层的名字generator/MODEL/outLayer。
Java调用TensorFlow的方法
得到pb文件、Input layer name和Output layer name就只差写Java代码调用啦。安装 Java 版 TensorFlow
这时候按照官方教程走,应该不会出什么问题。教程中现在用的是1.14.0的版本,所以python中也最好用1.14.0版本的TF。版本问题前面就已经提到过,不多说。
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.14.0</version>
</dependency>简易的载入模型和预测函数及使用:
# TensorFlowUtils.java
public final class TensorFlowUtils {
public static Session loadModel(String modelPath, Class<?> cls) {
try {
Graph graph = new Graph();
Session session = new Session(graph);
graph.importGraphDef(IOUtils.toByteArray(cls.getResourceAsStream(modelPath)));
return session;
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
public static void closeModel(Session session) {
session.close();
}
public static Tensor<Float> predict(Tensor<Float> input, Session session, String inputName, String outputName) {
return session.runner().feed(inputName, input).fetch(outputName).run().get(0).expect(Float.class);
}
}
# Main.java
public class Main {
public static void main(String[] args) {
Session session = TensorFlowUtils.loadModel("model/path", getClass());
// float[][][][] originInput = may be an image, or others
Tensor<Float> input = Tensor.create(originInput, Float.class);
Tensor<Float> output = TensorFlowUtils.predict(input, session,
"<your input name>", "<your output name>");
float[][] result = output.copyTo(new float[1][2]);
System.out.println(result[0]);
TensorFlowUtils.closeModel(session);
}
}
本文围绕视觉Demo中Java调用Python训练的TensorFlow模型展开。介绍了模型部署的两种方式,选择本地化部署。阐述部署原理,需将模型转为pb格式。还提及格式转换工具、找出model的input和output的方法,最后说明Java调用TensorFlow的方法及版本要求。
2100

被折叠的 条评论
为什么被折叠?



