1、tf2.1在keras中model的保存与调用
tf.keras.model类中的save_weights方法和load_weights方法保存模型的权重。
tf.keras.model.save方法可保存整个模型。保存的模型包括:
- The model architecture, allowing to re-instantiate the model.(模型的结构)
- The model weights.(模型的权重)
- The state of the optimizer, allowing to resume training exactly where you left off.(优化器的选择)
tf.keras.model.save方法允许在一个文件中将模型的全部状态记录下来。tf.keras.models.load_model可将保存下来的模型重建,并且支持直接使用。Models built with the Sequential and Functional API can be saved to both the HDF5 and SavedModel formats.
from keras.models import load_model
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del model # deletes the existing model
# returns a compiled model
# identical to the previous one
model = load_model('my_model.h5')
HDF5文件数据格式参考文章:
1)http://docs.h5py.org/en/latest/index.html
2)https://blog.youkuaiyun.com/mzpmzk/article/details/89188968
2、什么是PB文件,保存为pb文件示例
PB文件表示MetaGraph的protocal buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出signature以及asserts指创建计算图时额外的文件。
谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。
示例代码如下
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
from tensorflow.python.framework import graph_util
pb_file_path = os.getcwd()
with tf.Session(graph=tf.Graph()) as sess:
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')
sess.run(tf.global_variables_initializer())
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 测试 OP
feed_dict = {x: 10, y: 3}
print(sess.run(op, feed_dict))
# 写入序列化的 PB 文件
with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
# 输出
# INFO:tensorflow:Froze 1 variables.
# Converted 1 variables to const ops.
# 31
TensorFlow
在TensorFlow中,模型的持久化保存和加载主要通过Saver()。在初次训练之后调用如下的save函数保存,然后,在预测前,或者在继续训练前调用load加载参数即可。
def __init__():
self.sess = tf.Session()
# 定义好网络结构...
self.sess.run(tf.global_variables_initializer())
def check_path(self, path):
if not os.path.exists(path):
os.mkdir(path)
def save(self):
self.check_path('model')
saver=tf.train.Saver(tf.global_variables(),max_to_keep=10)
print("model: ",saver.save(self.sess,'model/modle.ckpt'))
def load(self):
saver=tf.train.Saver(tf.global_variables())
module_file = tf.train.latest_checkpoint('model')
saver.restore(self.sess, module_file)
参考文章:
【1】https://zhuanlan.zhihu.com/p/32887066
【2】https://blog.youkuaiyun.com/sunshinezhihuo/article/details/79705445