一个tf Estimator Summary Hook 函数
应用场景:最近在使用tf.estimator.DNNLinearCombinedClassifier时,希望能够使用tensorboard去看一下loss的下降曲线。
虽然Tensorflow有提供了tf.train.SummarySaverHook,但是研究后发现,并不适用于tf.estimator.DNNLinearCombinedClassifier,只适用于自定义model_fn的时候。
于是经过研究,自己写了个hook函数,经过测试可以用。
class SummaryHook(tf.train.SessionRunHook):
"""
Generate summary
"""
def begin(self):
self._step = -1
self.summary_op = tf.summary.merge_all()
def after_create_session(self, session, coord):
self.summary_writer = tf.summary.FileWriter('./log_summary', session.graph)
#self.summary_writer.add_summary(self.summary_op)
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(self.summary_op)
def after_run(self, run_context, run_values):
flush_step = 50
if self._step % flush_step == 0:
print("Flusing summary")
self.summary_writer.add_summary(run_values[0], global_step=self._step)
self.summary_writer.flush()