tf.get_variable
对传入的name的tensor进行初始化,在同一个graph下,如果不存在则初始化,如果存在则返回
tf.compat.v1.get_variable(
name, shape=None, dtype=None, initializer=None, regularizer=None,
trainable=None, collections=None, caching_device=None, partitioner=None,
validate_shape=True, use_resource=None, custom_getter=None, constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE
)
官方例子如下:
def foo():
with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
v = tf.get_variable("v", [1])
return v
v1 = foo() # Creates v.
v2 = foo() # Gets the same, existing v.
assert v1 == v2
有点像c++里面的单例模式
本文详细介绍了TensorFlow中tf.get_variable函数的使用方法,该函数能在同一个graph下对指定名称的tensor进行初始化,若已存在则直接返回。通过示例展示了如何在代码中实现变量的共享,类似于C++的单例模式。
1612

被折叠的 条评论
为什么被折叠?



