转载:https://blog.youkuaiyun.com/akadiao/article/details/78517154
tensorflow中tf.get_variable的API为
def get_variable(name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None):123456789101112
常用的参数有:
name:变量名称
shape:变量维度
initializer:变量初始化方式
regularizer:正规化
caching_device:可选的设备字符串或函数描述
其中,变量的初始化方式有
tf.constant_initializer–常量初始化
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
# 默认值为0
v1 = tf.get_variable('v1', shape=[5], initializer=tf.constant_initializer())
# 也可以指定初始化值
v2 = tf.get_variable('v2', shape=[5], initializer=tf.constant_initializer(9.))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print v1.eval()
print v2.eval()123456789101112131415
输出:
[ 0. 0. 0. 0. 0.]
[ 9. 9. 9. 9. 9.]
---------------------
作者:阿卡蒂奥
来源:优快云
原文:https://blog.youkuaiyun.com/akadiao/article/details/78517154
版权声明:本文为博主原创文章,转载请附上博文链接!