tensorflow中如何直接读取网络的参数(weight and bias)的值

本文介绍了如何检查神经网络训练过程中参数的变化情况。首先通过自定义函数获取特定变量名的参数,然后利用TensorFlow会话读取这些参数的具体数值。这对于验证网络是否正常训练至关重要。

训练好了一个网络,想要查看网络里面参数是否经过BP算法优化过,可以直接读取网络里面的参数,如果一直是随机初始化的值,则证明训练代码有问题,需要改。下面介绍如何直接读取网络的weight 和 bias。
(1) 获取参数的变量名。可以使用一下函数获取变量名:

def vars_generate1(self,scope_name_var):
    return [var for var in tf.global_variables() if scope_name_var in var.name ]

输入你想要读取的变量的一部分的名称(scope_name_var),然后通过这个函数返回一个List,里面是所有含有这个名称的变量。

(2) 利用session读取变量的值:

def get_weight(self):
  full_connect_variable = self.vars_generate1("pred_network/full_connect/l5_conv")
  with tf.Session() as sess:
     sess.run(tf.global_variables_initializer())  ##一定要先初始化变量
     print(sess.run(full_connect_variable[0]))

之后如果想要看参数随着训练的变化,你可以将这些参数保存到一个txt文件里面查看。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值