tensorflow变量管理与命名空间

本文深入探讨TensorFlow中变量管理及命名空间的使用,包括tf.Variable与tf.get_variable的区别,tf.variable_scope与tf.name_scope的功能,以及它们在变量创建、获取和命名冲突处理中的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在训练深度网络时,为了减少需要训练参数的个数、或是多机多卡并行化训练大数据大模型等情况时,往往需要共享变量。另外一方面是当模型变得非常复杂的时候,往往存在大量的变量和操作,如何避免这些变量名和操作名的唯一不重复,同时维护一个条理清晰的graph非常重要。

本文主要涉及tensorflow中变量管理和命名空间相关的函数:tf.Variable,tf.get_variable,tf.variable_scope,tf.name_scope的使用。

1. tf.Variable和tf.get_variable

       tensorflow可通过tf.Variable 函数来创建一个变量;通过tf.get_variable 函数来创建或者获取变量。当tf.get_variable 用于创建变量时,它和tf.Variable 的功能是基本等价的。如以下代码:

#下面这两个定义是等价的
v = tf.get variable ('v', shape=[1], initializer=tf.constant_initializer(1.0))
v = tf.Variable(tf.constant(1.0 , shape=[1]), name ='v')

       tf.get_variable函数与tf.Variable函数最大的区别在于指定变量名称的参数。对于tf.Variable,变量名称是一个可选参数。但是对于tf.get_variable,变量名称是一个必填参数。tf.get_variable会根据这个名字去创建或者获取变量。在上述代码中, tf.get_variable首先会试图去创建一个名字为v的参数,如果创建失败(比如已经有同名的参数),那么这个程序就会报错。如果需要通过tf.get_variable获取一个已创建变量,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable 将直接获取已经生成的变量。

# 在名字为foo的命名空间内创建名字为v的变量。
with tf.variable_scope('foo'):
	v = tf.get_variable('v', [1], initializer=tf.constant_initializer(1.0))
	
# 因为在命名空间foo中已存在名字为v的变量,所以以下代码将会报错:
# Variable foo/v already exists, disallowed. Did you mean to set reuse=True in VarScope?
with tf.variable_scope('foo'):
	v = tf.get_variable('v', [1])
	
# 在生成上下文管理器时,将参数reuse设置为True,这样tf.get_variable函数将直接获取已经声明的变量。
with tf.variable_scope('foo', reuse=True):
	v1 = tf.get_variable('v', [1])
	print(v == v1) #输出为True ,代表v, v1代表的是相同的TensorFlow变量。

# 将参数reuse设置为True时, tf.variable_scope 将只能获取巳经创建过的变量。
# 因为在命名空间bar中还没有创建变量v,所以以下代码将会报错:
# Variable bar/v does not exist, disallowed. Did you mean to set reuse=None in VarScope?
with tf. variable scope ('bar', reuse=True) :
	v = tf.get_variable('v', [1])

2. tf.get_variable和tf.variable_scope

       在以上代码中,已经可以看到,通过tf.variable_scope函数可以控制tf.get_variable的语义。当tf.variable_scope函数使用参数reuse=True生成上下文管理器时,这个上下文管理器内所有的tf.get_variable函数会直接获取已经创建的变量。如果变量不存在,则tf.get_variable函数将报错;相反,如果tf.variable_scope函数使用默认参数reuse=None或者reuse=False创建上下文管理器,tf.get_variable操作将创建新的变量。如果同名的变量已经存在,则tf.get_variable函数将报错。

       tf.variable_scope 函数生成的上下文管理器也会创建一个TensorFlow中的命名空间,在命名空间内创建的变量名称都会带上这个命名空间名作为前缀。所以,tf.variable_scope除了可以控制tf.get_variable执行的功能,这个函数也提供了一个管理变量命名空间的方式。如以下代码:

v1 = tf.get_variable('v', [1])
print(v1.name) #输出v:0,'v'为变量的名称,':0'表示该变量是生成变量这一运算的第一个结果

with tf.variable_scope('foo'):
	v2 = tf.get_variable('v', [1])
	print(v2.name) #输出foo/v:0,通过/来分隔命名空间的名称和变量的名称。

3. tf.variable_scope,tf.name_scope

       除了tf.variable_scope函数,tf.name_scope函数也提供了命名空间管理的功能。这两个函数在大部分情况下是等价的,唯一的区别是在使用tf.get_variable函数时。如下代码:

with tf.variable_scope('foo'):
	a = tf.get_variable('bar', [1])
	print(a.name)  #输出: foo/bar:0
	
with tf.variable_scope('bar'):
	b = tf.get_variable('bar', [1])
	print(b.name)  #输出: bar/bar:0,与命名空间foo下的同名变量不冲突

with tf.name_scope('s1'):
	a = tf.Variable([1], name='a')
	print(a.name)  #输出s1/a:0。tf.Variable生成的变量会受f.name_scope影响
	
	b = tf.get_variable('b', [1])
	print(b.name)  #输出b:0。tf.get_variable生成的变量不受f.name_scope影响

	c = 1 + a
	print(c.name) #输出s1/add:0。op的输出会受f.name_scope影响,返回的c是tf.Tensor对象

	d = 1 + b
	print(d.name) #输出s1/add_1:0。op的输出会受tf.name_scope影响,且Tensor重名会自动处理,与以下的tf.Variable相似。

	e = tf.Variable([1], name='a')
	print(e.name) #输出s1/a_1:0。定义与a一样,重名会自动处理

with tf.name_scope('s1'):
	f = tf.Variable([1], name='a')
	print(e.name) #输出s1_1/a:0。定义与s1/a一样,重名会自动处理

       在使用TensorBoard进行可视化时,用TensorFlow命名空间进行封装能更加清晰地呈现计算图信息。在TensorBoard的默认视图中,TensorFlow计算图中同一个命名空间下的所有节点会被缩略成一个节点,只有顶层命名空间中的节点才会被显示在TensorBoard可视化效果图上。

4. 小结

  1. tf.Variable(以及tf.Tensor)会自动检测命名冲突并自行处理。对于使用tf.Variable来说,tf.name_scope和tf.variable_scope功能一样,都是给变量加前缀,相当于分类管理,模块化。
  2. tf.get_variable用于创建/获取一个变量,并且不受name_scope的约束。根据是否设置为共享变量,如果未设置共享变量,若变量不存在,则创建变量;当变量已存在,则报错;如果设置了共享变量,若变量不存在,则报错;当变量已存在,则获取变量。
  3. tf.get_variable和tf.variable_scope必须要搭配使用,为变量共享提供支持。
  4. 如果想实现“有则共享,无则新建”的方式,可以使用with tf.variable_scope(’’, reuse=tf.AUTO_REUSE)
  5. tf.name_scope主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。一个graph会维护一个namespace的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。
  6. tf.variable_scope一般与tf.name_scope配合使用,用于管理一个graph中变量的名字,避免变量之间的命名冲突。同时用于允许在一个variable_scope下面共享变量。
  7. 当多个tf_variable_scope嵌套时,如果中间某层开启了reuse=True, 则内层自动全部共享,即使内层设置了reuse=False。而且,一旦使用tf.get_variable_scope().reuse_variables()打开了当前域共享,就不能关闭了!

       最后放上一张图:

5. 参考

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值