在训练深度网络时,为了减少需要训练参数的个数、或是多机多卡并行化训练大数据大模型等情况时,往往需要共享变量。另外一方面是当模型变得非常复杂的时候,往往存在大量的变量和操作,如何避免这些变量名和操作名的唯一不重复,同时维护一个条理清晰的graph非常重要。
本文主要涉及tensorflow中变量管理和命名空间相关的函数:tf.Variable,tf.get_variable,tf.variable_scope,tf.name_scope的使用。
1. tf.Variable和tf.get_variable
tensorflow可通过tf.Variable 函数来创建一个变量;通过tf.get_variable 函数来创建或者获取变量。当tf.get_variable 用于创建变量时,它和tf.Variable 的功能是基本等价的。如以下代码:
#下面这两个定义是等价的
v = tf.get variable ('v', shape=[1], initializer=tf.constant_initializer(1.0))
v = tf.Variable(tf.constant(1.0 , shape=[1]), name ='v')
tf.get_variable函数与tf.Variable函数最大的区别在于指定变量名称的参数。对于tf.Variable,变量名称是一个可选参数。但是对于tf.get_variable,变量名称是一个必填参数。tf.get_variable会根据这个名字去创建或者获取变量。在上述代码中, tf.get_variable首先会试图去创建一个名字为v的参数,如果创建失败(比如已经有同名的参数),那么这个程序就会报错。如果需要通过tf.get_variable获取一个已创建变量,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable 将直接获取已经生成的变量。
# 在名字为foo的命名空间内创建名字为v的变量。
with tf.variable_scope('foo'):
v = tf.get_variable('v', [1], initializer=tf.constant_initializer(1.0))
# 因为在命名空间foo中已存在名字为v的变量,所以以下代码将会报错:
# Variable foo/v already exists, disallowed. Did you mean to set reuse=True in VarScope?
with tf.variable_scope('foo'):
v = tf.get_variable('v', [1])
# 在生成上下文管理器时,将参数reuse设置为True,这样tf.get_variable函数将直接获取已经声明的变量。
with tf.variable_scope('foo', reuse=True):
v1 = tf.get_variable('v', [1])
print(v == v1) #输出为True ,代表v, v1代表的是相同的TensorFlow变量。
# 将参数reuse设置为True时, tf.variable_scope 将只能获取巳经创建过的变量。
# 因为在命名空间bar中还没有创建变量v,所以以下代码将会报错:
# Variable bar/v does not exist, disallowed. Did you mean to set reuse=None in VarScope?
with tf. variable scope ('bar', reuse=True) :
v = tf.get_variable('v', [1])
2. tf.get_variable和tf.variable_scope
在以上代码中,已经可以看到,通过tf.variable_scope函数可以控制tf.get_variable的语义。当tf.variable_scope函数使用参数reuse=True生成上下文管理器时,这个上下文管理器内所有的tf.get_variable函数会直接获取已经创建的变量。如果变量不存在,则tf.get_variable函数将报错;相反,如果tf.variable_scope函数使用默认参数reuse=None或者reuse=False创建上下文管理器,tf.get_variable操作将创建新的变量。如果同名的变量已经存在,则tf.get_variable函数将报错。
tf.variable_scope 函数生成的上下文管理器也会创建一个TensorFlow中的命名空间,在命名空间内创建的变量名称都会带上这个命名空间名作为前缀。所以,tf.variable_scope除了可以控制tf.get_variable执行的功能,这个函数也提供了一个管理变量命名空间的方式。如以下代码:
v1 = tf.get_variable('v', [1])
print(v1.name) #输出v:0,'v'为变量的名称,':0'表示该变量是生成变量这一运算的第一个结果
with tf.variable_scope('foo'):
v2 = tf.get_variable('v', [1])
print(v2.name) #输出foo/v:0,通过/来分隔命名空间的名称和变量的名称。
3. tf.variable_scope,tf.name_scope
除了tf.variable_scope函数,tf.name_scope函数也提供了命名空间管理的功能。这两个函数在大部分情况下是等价的,唯一的区别是在使用tf.get_variable函数时。如下代码:
with tf.variable_scope('foo'):
a = tf.get_variable('bar', [1])
print(a.name) #输出: foo/bar:0
with tf.variable_scope('bar'):
b = tf.get_variable('bar', [1])
print(b.name) #输出: bar/bar:0,与命名空间foo下的同名变量不冲突
with tf.name_scope('s1'):
a = tf.Variable([1], name='a')
print(a.name) #输出s1/a:0。tf.Variable生成的变量会受f.name_scope影响
b = tf.get_variable('b', [1])
print(b.name) #输出b:0。tf.get_variable生成的变量不受f.name_scope影响
c = 1 + a
print(c.name) #输出s1/add:0。op的输出会受f.name_scope影响,返回的c是tf.Tensor对象
d = 1 + b
print(d.name) #输出s1/add_1:0。op的输出会受tf.name_scope影响,且Tensor重名会自动处理,与以下的tf.Variable相似。
e = tf.Variable([1], name='a')
print(e.name) #输出s1/a_1:0。定义与a一样,重名会自动处理
with tf.name_scope('s1'):
f = tf.Variable([1], name='a')
print(e.name) #输出s1_1/a:0。定义与s1/a一样,重名会自动处理
在使用TensorBoard进行可视化时,用TensorFlow命名空间进行封装能更加清晰地呈现计算图信息。在TensorBoard的默认视图中,TensorFlow计算图中同一个命名空间下的所有节点会被缩略成一个节点,只有顶层命名空间中的节点才会被显示在TensorBoard可视化效果图上。
4. 小结
- tf.Variable(以及tf.Tensor)会自动检测命名冲突并自行处理。对于使用tf.Variable来说,tf.name_scope和tf.variable_scope功能一样,都是给变量加前缀,相当于分类管理,模块化。
- tf.get_variable用于创建/获取一个变量,并且不受name_scope的约束。根据是否设置为共享变量,如果未设置共享变量,若变量不存在,则创建变量;当变量已存在,则报错;如果设置了共享变量,若变量不存在,则报错;当变量已存在,则获取变量。
- tf.get_variable和tf.variable_scope必须要搭配使用,为变量共享提供支持。
- 如果想实现“有则共享,无则新建”的方式,可以使用with tf.variable_scope(’’, reuse=tf.AUTO_REUSE)
- tf.name_scope主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。一个graph会维护一个namespace的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。
- tf.variable_scope一般与tf.name_scope配合使用,用于管理一个graph中变量的名字,避免变量之间的命名冲突。同时用于允许在一个variable_scope下面共享变量。
- 当多个tf_variable_scope嵌套时,如果中间某层开启了reuse=True, 则内层自动全部共享,即使内层设置了reuse=False。而且,一旦使用tf.get_variable_scope().reuse_variables()打开了当前域共享,就不能关闭了!
最后放上一张图:

5. 参考
- 《TensorFlow实战Google深度学习框架》
- tensorflow里面name_scope, variable_scope等如何理解?