tensolrflow中get_variable和tf.Variable区别

本文详细探讨了TensorFlow中tf.Variable和tf.get_variable的区别。tf.Variable主要用于创建新变量,而tf.get_variable则涉及变量的复用和命名空间管理。在使用tf.get_variable时,通过设置reuse参数可以控制是否重复使用或创建新的变量,常与tf.variable_scope结合以实现变量共享。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

技术交流qq群: 659201069

  先来看下二者的定义:
  
Variable:必须给定的参数只有一个initial_value,如果名字没指定会自己生成一个Variable:0类似于此

  def __init__(self,
               initial_value=None,
               trainable=True,
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,
               variable_def=None,
               dtype=None,
               expected_shape=None,
               import_scope=None):

get_variable:必须给定的参数只有name

def get_variable(name,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=True,
                 collections=None,
                 caching_device=None,
                 partitioner=None,
                 validate_shape=True,
                 use_resource=None,
                 custom_getter=None):

  tf.Variable是用于生成一个新的变量,tf.get_variable是获取一个已经存在的变量,如果不存在再重建,在tensorflow中每个变量都有一个scope也就是命名空间,一个变量要的全名由scope和name唯一确定,因为tf.Variable不涉极变量复用问题,命名空间意义不是太明显,只有在用tf.get_variable时才会体会用命名空间的作用。
  

tf.Variable看使用详解:

import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;

a1 = tf.Variable(initial_value=[[1.0,2.0,3.0],[1.0,2.0,3.0]],name="a1")
a2 = tf.Variable(initial_value=[[1.0,2.0,3.0],[1.0,2.0,3.0]],name="a2")
a3 = tf.Variable(initial_value=[[1.0,2.0,3.0],[1.0,2.0,3.0]],name="a1")
with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a1))
        print(sess.run(a2))
        print(sess.run(a3))
        print(a1)
        print(a2)
        print(a3)

输出如下所示:
这里写图片描述
 此处没有调用with tf.variable_scope(“scope_1”):,可以认为所有变量的scope为空,a3定义时name=”a1”,在这种情况下tensorflow会把a3重命名。也就是tf.Variable无论何时都能创建成功,无论是否有同名变量
 
  

tf.get_variable使用详解:

# a1 = tf.get_variable(name='a1', shape=[2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
with tf.Session() as sess:
    with tf.variable_scope("scope_1", reuse=False):
        a1 = tf.get_variable(name='a1', shape=[2, 3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
        # a1 = tf.get_variable(name='a1')
        sess.run(tf.global_variables_initializer())
        print(sess.run(a1))

输出如下所示:
这里写图片描述

 把with tf.variable_scope(“scope_1”, reuse=False):中的resue设为True则上面的代码运行报错,这说明只有当 reuse=False时才可以创建变量, reuse=True时只能获取已有的变量或者通过tf.Variable创建新的变量,使用tf.get_variable一般都是和with tf.variable_scope(“scope_1”, reuse=True)配合使用的,共享已有变量,如果需要新变量则通过tf.Variable来创建

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

阿童木-atom

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值