TensorFlow系列——在自定义的标准estimator中使用tensorboard及打印中间数据

本文介绍了如何在TensorFlow的自定义标准estimator中利用hook钩子函数获取并记录中间数据,同时展示了如何将这些数据输出到TensorBoard以便于模型监控和分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、定义hook钩子函数用于获取指定名称的中间数据

1、定义hook钩子类用于获取模型中指定名称的中间数据

class YourOwnHook(tf.train.SessionRunHook):
    def __init__(self):
        np.set_printoptions(suppress=True)
        np.set_printoptions(linewidth=400)

    def before_run(self, run_context):
        """返回SessionRunArgs和session run一起跑"""
        v1 = tf.get_collection('logis')
        prob = tf.get_collection('prob')
        return tf.train.SessionRunArgs(fetches=[v1, prob])
    def after_run(self, run_context, run_values):
        v1, batch_labels = run_values.results
        logger.info("logis value:{}".format(v1))
        print("prob :",batch_labels)

2、标准的自定义的estimator以及设置钩子用于输出到tensorboard以及输出中间值

class MyEstimator(tf.estimator.Estimator):
    def __init__(self,
                          model_dir,
                          hidden_units,
     
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值