一、模型保存
1、tf.train.saver
import tensorflow as tf
...
#在这里构建网络
...
#开始保存模型
与tf.Session()作为sess:
sess.run(tf.global_variables_initializer())#一定要先初始化整个流
#在这里训练网络
...
#保存参数
saver = tf.train.Saver()
saver.save(sess,PATH)#PATH就是要保存的路径
2、tf.saved_model.builder
将tensorflow import 为tf
...
#构建网络
...
用tf.Session()作为sess:
sess.run(tf.global_variables_initializer())#一定要先初始化整个流
#在这里训练网络
...
#保存参数
builder = tf.saved_model.builder.SaveModelBuilder(PATH)#PATH是保存路径
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING])#保存整张网络及其变量,这种方法是可以保存多张网络的,在此不作介绍,可自行了解
builder.save()#完成保存
二、模型调用
想要完整调用并使用一个训练好的模型,必须分为加载网络和加载关键节点两个部分。使用tf.saved_model.builder则可以完整的保存和调用模型的网络与节点
1、加载网络
将tensorflow import 为tf
用tf.Session(graph = tf.Graph())作为sess:
tf.saved_model.loader.load(sess,[tf.sa