slim训练获取session,自定义log输出

本文详述了如何在使用TensorFlow的slim模块进行模型训练时,通过自定义train_step_fn函数来控制和打印日志信息。利用slim.train()函数,结合train_step函数,可以实现在特定步骤打印学习率等关键信息,为模型训练过程提供更详细的监控。

tensorflow的slim模块封装了大量的模块,使用起来简单方便,但有时候想要控制log的打印却也是找不到相应的API,使用slim训练模型时,print语句往往是无法达到预期效果的,下面介绍如何使用train_step控制log

1. slim.train

   slim.train()函数定义如下,初始化模型和参数后,需要调用下面的函数进行模型训练:

def train(train_op,
          logdir,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_steps=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          session_wrapper=None,
          trace_every_n_steps=None,
          ignore_live_threads=False)

接下来需要获取sess的控制权,通过session打印log输出,同时要定义slim.train()中的train_step_fn函数,在该函数中调用train_step函数,然后实现自己想要的逻辑

2. train_step

from tensorflow.contrib.slim.python.slim.learning import train_step


def train_step_fn(session,  *xarg, **train_step_kwargs):
     total_loss, should_stop = train_step(session, *xarg, **train_step_kwargs)
     if train_step_fn.step % 4 ==0:
        tf.logging.info(session.run(train_step_fn.lr))

train_step_fn.lr = learning_rate
train_step_fn.step = step

final_loss = slim.learning.train(train_op, TRAIN_LOG, 
                        train_step_fn=train_step_fn,
                        init_fn=init_fn,
                        global_step=global_step,
                        number_of_steps=steps,
                        save_summaries_secs=60,
                        save_interval_secs=600,
                        session_config=sess_config,
                      )

这样就能获取sess的控制权,完成自定义log的打印

参考:https://github.com/google-research/tf-slim/blob/master/tf_slim/learning.py

https://xbuba.com/questions/48898117

评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值