TensorFlow之变量详解

 

我们已经知道了如何定义变量,我们再来回顾一下。

1、创建

创建一个变量时,将一个张量作为初始值传入构造函数Variable()。另外TensorFlow提供了一系列操作符来初始化张量,初始值是常量或是随机值。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")

注意:以上只是定义变量将要以哪个值来进行初始化,而还没有真正去初始化这个变量。所以需要用到下面初始化的方法去实现这个初始化动作。

 

2、初始化

变量的初始化必须在模型的其它操作运行之前先明确地完成。通常使用tf.global_variables_initializer()添加一个操作对变量做初始化。记得在完全构建好模型并加载之后再运行那个操作。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Later, when launching the model
with tf.Session() as sess:
  # Run the init operation.
  sess.run(init_op)
  ...
  # Use the model
  ...

 

3、由另一个变量初始化

有时候需要用另一个变量的初始化值给当前变量初始化,可以使用其它变量的initialized_value()属性:

# Create a variable with a random value.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

 

4、保存变量

tf.train.Saver()创建一个Saver来管理模型中的所有变量。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print "Model saved in file: ", save_path

 

5、恢复变量

用同一个Saver对象来恢复变量。注意,当你从文件中恢复变量时,不需要事先对它们做初始化。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print "Model restored."
  # Do some work with the model
  ...

注意:如果不给tf.train.Saver()传入任何参数,那么saver将处理graph中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。

 

6、选择存储和恢复哪些变量

可以通过给tf.train.Saver()构造函数传入Python字典,定义需要保持的变量及对应名称:键对应使用的名称,值对应被管理的变量

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...

注意:

  • 如果需要保存和恢复模型变量的不同子集,可以创建任意多个saver对象。同一个变量可被列入多个saver对象中,只有当saver的restore()函数被运行时,它的值才会发生改变。
  • 如果你仅在session开始时恢复模型变量的一个子集,你需要对剩下的变量执行初始化op。

 

 

参考:http://www.tensorfly.cn/tfdoc/how_tos/variables.html

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值