目前深度学习主要使用Python训练自己的模型,其中Keras提供了heigh-level语法,后端可采用Tensorflow或者Theano。
但是在实际应用时,大多数公司仍是使用java作为应用系统后台。于是便有了Python离线训练模型,Java调用模型实现在线预测。
Java调用Keras模型有两种方案,一种是基于Java的深度学习库DL4J导入Keras模型,另外一种是利用Tensorflow的java接口调用。DL4J目前暂不支持嵌套模型的导入,下面仅介绍第二种方案。
想要利用Tensorflow的java接口调用Keras模型,就需要将Keras保存的模型文件(.h5)转换为Tensorflow的模型文件(.pb)。
GitHub 已经有大神写了一个转换工具可以很方便的将Keras模型转Tensorflow模型,你只需要输入原模型文件的位置 和 目标模型文件的位置即可。
Keras的模型可以通过model.save() 方法保存为一个单独的模型文件(model.h5),此模型文件包含网络结构和权重参数。
此种模型可通过以下代码将Keras模型转换为Tensorflow模型:
python keras_to_tensorflow.py
--input_model="path/to/keras/model.h5"
--output_model="path/to/save/model.pb"
Keras的模型也可以通过 model.to_json() 和 model.save_weight() 来分开保存模型的结构(model.json)和 权重参数(weights.