Tensorflow学习笔记——tf.Variable()和tf.get_variable()

本文介绍了tf.Variable()和tf.get_variable的相关内容。tf.Variable是一个类,可创建图中的变量,常用参数有初始化值和名称,使用前需初始化;tf.get_variable用于获取或创建变量。二者区别在于,tf.Variable会自行处理命名冲突,而tf.get_variable会报错,共享变量时用tf.get_variable。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.Variable()

tf.Variable是一个Variable类
通过variable维持图graph的状态,以便在sess.run()中执行,可以用Variable类创建一个实例在图中增加变量

tf.Variable(
initial_value=None, 
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None
)
  • initial_value
    Tensor或可转换为Tensor的Python对象,它是Variable的初始值。除非validate_shape设置为False,否则初始值必须具有指定的形状。也可以是一个可调用,没有参数,在调用时返回初始值。在这种情况下,必须指定dtype。
  • name
    变量的可选名称。默认为“Variable”并自动获取。
  • dtype
    如果设置,则initial_value将转换为给定类型。如果为None,则保留数据类型(如果initial_value是Tensor),或者convert_to_tensor将决定。
  • collection
    一个图graph集合列表的关键字。新变量将添加到这个集合中。默认为[GraphKeys.GLOBAL_VARIABLES]。也可自己指定其他的集合列表

一般常用的参数包括初始化值和名称(是该变量的唯一索引)
在使用变量之前必须要进行初始化

tf.get_variable

获取一个已经存在的变量或者创建一个新的变量

get_variable(
	name,
	shape=None,
	dtype=None,
	initializer=None,
	regularizer=None,
	trainable=True,
	collections=None,
	caching_device=None,
	partitioner=None,
	validate_shape=True,
	use_resource=None,
	custom_getter=None,
	constraint=None
)
  • name
    新变量或现有变量的名称
  • shape
    新变量或现有变量的形状
  • dtype
    新变量或现有变量的类型
  • initializer
    如果创建了则用它初始化变量
  • collection
    一个图graph集合列表的关键字。新变量将添加到这个集合中。默认为[GraphKeys.GLOBAL_VARIABLES]。也可自己指定其他的集合列表

区别

  1. 使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错。
  2. 基于这两个函数的特性,当我们需要共享变量的时候,需要使用tf.get_variable()。在其它两种情况下,这两个的用法是一样的。

Reference:
https://blog.youkuaiyun.com/MrR1ght/article/details/81228087
https://www.jianshu.com/p/2061b221cd8f?utm_campaign=maleskine&utm_content=note&utm_medium=seo_notes&utm_source=recommendation

在神经网络训练完成后,我们可以使用模型对新的输入数据进行预测。以下是使用 TensorFlow 进行简单神经网络训练测试的示例代码: ```python import tensorflow as tf # 定义输入数据 x = tf.placeholder(tf.float32, [None, input_size]) # 定义模型参数 W = tf.Variable(tf.zeros([input_size, num_classes])) b = tf.Variable(tf.zeros([num_classes])) # 定义模型 y = tf.nn.softmax(tf.matmul(x, W) + b) # 定义损失函数 y_ = tf.placeholder(tf.float32, [None, num_classes]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) # 定义优化器 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy) # 训练模型 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(num_iterations): batch_xs, batch_ys = get_next_batch(train_data, train_labels, batch_size) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # 测试模型 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print("Accuracy: ", sess.run(accuracy, feed_dict={x: test_data, y_: test_labels})) ``` 在训练完成后,我们可以使用 `sess.run()` 函数来获取模型的输出。例如,如果我们想要对一个新的数据样本进行预测,我们可以将其传递给 `sess.run()` 函数的 `feed_dict` 参数,并获取模型的输出 `y`。下面是一个简单的示例: ```python # 定义要预测的数据 new_data = [[5.1, 3.5, 1.4, 0.2]] # 获取预测结果 prediction = sess.run(y, feed_dict={x: new_data}) # 打印预测结果 print(prediction) ``` 在上面的代码中,我们将 `new_data` 作为输入数据传递给模型,并使用 `sess.run()` 函数获取模型的输出 `y`。然后,我们打印预测结果即可。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值