1,基本内容
目的是将模型数据以文件的形式保存到本地。
使用神经网络模型进行大数据量和复杂模型训练时,训练时间可能会持续增加,此时为避免训练过程出现不可逆的影响,并验证训练效果,可以考虑分段进行,将训练数据模型保存,然后在继续训练时重新读取; 此外,模型训练完毕,获取一个性能良好的模型后,可以保存以备重复利用。
2,参数保存和读取代码
import tensorflow as tf
#随机初始化两个变量
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")#矩阵大小为[1,2]
v2 = tf.Variable(tf.random_normal([2,4]), name="v2")#矩阵大小为[2,4]
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()#定义该类的一个对象
with tf.Session() as sess:
sess.run(init_op)
print ("V1:",sess.run(v1))
print ("V2:",sess.run(v2))
saver_path = saver.save(sess, "Save/model.ckpt")#保存sess计算域中所有的参数值
print ("Model saved")
saver.restore(sess, "Save/model.ckpt")#读取保存的文件
print ("V1_1:",sess.run(v1))
print ("V2_1:",sess.run(v2))
print ("Model restored")
运行结果:
2,网络模型的保存与读取代码
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.trai