tensorflow tf.Variable()和tf.get_variable()详解
一、tf.Variable()
tf.Variable(
initial_value=None, trainable=None, validate_shape=True, caching_device=None,
name=None, variable_def=None, dtype=None, import_scope=None, constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.compat.v1.VariableAggregation.NONE, shape=None
)
(1)参数说明:
initial_value:默认为None,可以搭配tensorflow随机生成函数。
trainable:默认为True,可以后期被算法优化的。如果不想该变量被优化,改为False。
validate_shape:如果为False,则允许使用未知形状的值初始化变量。如果为True,则默认为initial_value的形状必须已知
name:默认为None,给变量确定名称。
dtype:如果设置,则initial_value将转换为给定类型。如果为None,则保留数据类型(如果initial_value是Tensor),或者convert_to_tensor将决定
(2)使用tf.Variable定义变量常用的两种方式
1、用固定的值初始化变量
w = tf.Variable(initial-value,name=optional-name)
2、用tf的初始化器初始化变量
w = tf.Variable(tf.truncated_normal([3,4],mean=0,stddev=.5),name=‘weight’)
用tf的初始化器initializer op初始化变量必须指定变量shape,用name指定名称
(3)执行变量初始化的三种方式
在使用变量之前必须要进行初始化,初始化的方式有三种:
1、在会话中执行
tf.global_variable_initializer().run()
2、从文件中恢复,如restore from checkpoint
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(checkpoint_path)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
else:
print('No checkpoint file found')
变量初始化方式1、2也就是在模型训练和测试中常用的两种方式,在模型训练时,需要随机给模型赋初值,使用tf.global_variable_initializer().run()去初始化变量,在模型测试(或者进行fine-tune)时,使用初始化方式2从保存的ckpt中初始化变量。
3、也可自己通过tf.assign()给变量附初值,
实际上用initializer初始化方法给变量赋初值就是调用tf.assign()将变量的值赋给变量,可以自己调用tf.assign()给变量赋初值;
a = tf.Variable(1.0)
a = tf.assign(a,5.0)
b = tf.Variable(2.0)
b = tf.assign(b,6.0)
c = a+b
with tf.Session() as sess:
#相当于调用tf.assign()给变量赋初值
sess.run([a,b])
print(c.eval())
二 tf.get_variable()
常用的参数有:名称name、变量规格shape、变量类型dtype、变量初始化方式initializer、所属于的集合collections。
get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None
)
name:新变量或现有变量的名称。
shape:新变量或现有变量的形状。
dtype:新变量或现有变量的类型(默认为DT_FLOAT)。
ininializer:如果创建了则用它来初始化变量。
该函数的作用是创建新的tensorflow变量,
常见的initializer有:
常量初始化器tf.constant_initializer、
正太分布初始化器tf.random_normal_initializer、
截断正态分布初始化器tf.truncated_normal_initializer、
均匀分布初始化器tf.random_uniform_initializer。
import tensorflow as tf;
import numpy as np;
#常量初始化器
v1_cons = tf.get_variable('v1_cons', shape=[1,4], initializer=tf.constant_initializer())
v2_cons = tf.get_variable('v2_cons', shape=[1,4], initializer=tf.constant_initializer(9))
#正太分布初始化器
v1_nor = tf.get_variable('v1_nor', shape=[1,4], initializer=tf.random_normal_initializer())
v2_nor = tf.get_variable('v2_nor', shape=[1,4], initializer=tf.random_normal_initializer(mean=0, stddev=5, seed=0))#均值、方差、种子值
#截断正态分布初始化器
v1_trun = tf.get_variable('v1_trun', shape=[1,4], initializer=tf.truncated_normal_initializer())
v2_trun = tf.get_variable('v2_trun', shape=[1,4], initializer=tf.truncated_normal_initializer(mean=0, stddev=5, seed=0))#均值、方差、种子值
#均匀分布初始化器
v1_uni = tf.get_variable('v1_uni', shape=[1,4], initializer=tf.random_uniform_initializer())
v2_uni = tf.get_variable('v2_uni', shape=[1,4], initializer=tf.random_uniform_initializer(maxval=-1., minval=1., seed=0))#最大值、最小值、种子值
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("常量初始化器v1_cons:",sess.run(v1_cons))
print("常量初始化器v2_cons:",sess.run(v2_cons))
print("正太分布初始化器v1_nor:",sess.run(v1_nor))
print("正太分布初始化器v2_nor:",sess.run(v2_nor))
print("截断正态分布初始化器v1_trun:",sess.run(v1_trun))
print("截断正态分布初始化器v2_trun:",sess.run(v2_trun))
print("均匀分布初始化器v1_uni:",sess.run(v1_uni))
print("均匀分布初始化器v2_uni:",sess.run(v2_uni))
输出结果:
常量初始化器v1_cons: [[0. 0. 0. 0.]]
常量初始化器v2_cons: [[9. 9. 9. 9.]]
正太分布初始化器v1_nor: [[-0.7286455 -0.03095582 1.6400269 -0.90134907]]
正太分布初始化器v2_nor: [[-1.9957879 10.522196 0.8553612 2.7325907]]
截断正态分布初始化器v1_trun: [[-0.52284956 -0.77045 1.9507815 0.96106136]]
截断正态分布初始化器v2_trun: [[-1.9957879 0.8553612 2.7325907 2.1127698]]
均匀分布初始化器v1_uni: [[0.5369104 0.05912018 0.1587832 0.2859378 ]]
均匀分布初始化器v2_uni: [[ 0.79827476 -0.9403336 -0.69752836 0.9034374 ]]
可能引发的异常:
ValueError:当创建新的变量和形状时,在变量创建时违反重用,或当 initializer 的 dtype 和 dtype 不匹配时。可在 variable_scope 中设置重用。