训练模型时,用变量来存储和更新参数。变量包含张量存放于内存的缓存区,建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。这些变量的值可在之后模型训练和分析时被加载。
一、变量的创建
创建变量时,将一个张量作为初始值传入构造函数Variable( )。初始值一般是常量或是随机值。
下面各举初始值为常量、随机值的例子。
#初始值是随机量
weights=tf.Variable(tf.random_normal([784,200],stddev=0.35),name="weights")
#初始值是常量
biases=tf.Variable(tf.zeros([200]),name="biases")
更多的初始化张量的操作符,可以参见TF官方文档。(PS:官方文档永远是最好的学习资料)
二、变量的初始化
变量的初始化必须在模型的其他操作(OP)运行之前明确地完成。变量的初始化方法有两种:1.添加一个给所有变量初始化的OP,并在使用模型之前运行此OP。2.用另一个变量的初始化值给当前变量初始化。
#添加一个给所有变量初始化的OP,并在使用模型之前运行此OP
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
...
#用另一个变量的初始化值给当前变量初始化(用一中所定义的weights作为前提条件)
w2=tf.Variable(weights.initialized_value(),name="w2")
三、变量的存储
用tf.train.Saver()创建一个Saver来管理模型中的所有变量.
以下是来自官方文档的一个例子
#Creat a saver
saver=tf.train.Saver(... Variables ...)
#Launch the graph and train, saving the model every 1,000 steps.
sess=tf.Session()
for step in xrange(100000):
sess.run(...training_op...)
if step%1000==0:
# Append the step number to the checkpoint name:
saver.save(sess,'my-model',global_step=step)
四、恢复变量
用同一个Saver对象来恢复变量。(当从文件中恢复变量时,不需要再次初始化对象)
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print "Model restored."
# Do some work with the model
注意:
1.如果需要保存和恢复模型变量的不同子集,可以创建任意多个saver对象。同一个变量可被列入多个saver对象中,只有当saver的restore()函数被运行时,它的值才会发生改变。
2.如果你仅在session开始时恢复模型变量的一个子集,你需要对剩下的变量执行初始化op。