There are two ways to implement finetuning:
The first is to initialize the variables directly from the pretrained checkpoint right after the model has been defined inside model_fn:
def model_fn(features, labels, mode, params):
    # ... build the model ...
    # Finetune: warm-start from a pretrained checkpoint, but only on the first
    # run, i.e. while model_dir does not yet contain a checkpoint of its own.
    if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
        checkpoint_path = None
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: use its latest checkpoint.
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path
        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            # Map every variable under the given checkpoint scope onto the same
            # scope in the current graph, e.g. 'OptimizeLoss/': 'OptimizeLoss/'
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}
        )
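For context, here is a minimal end-to-end sketch of this first approach on TF 1.x. Only params.checkpoint_path, params.checkpoint_scope and params.model_dir come from the snippet above; the toy linear model, the HParams wiring and the paths are illustrative assumptions, not part of the original.

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Toy model: a single dense layer under a named variable scope, so the
    # scope name can be reused in the assignment_map below.
    with tf.variable_scope('OptimizeLoss'):
        predictions = tf.layers.dense(features['x'], units=1)
    loss = tf.losses.mean_squared_error(labels, predictions)

    # Warm-start only while model_dir holds no checkpoint, so a resumed run is
    # not overwritten by the pretrained weights.
    if params.checkpoint_path and not tf.train.latest_checkpoint(params.model_dir):
        checkpoint_path = params.checkpoint_path
        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            assignment_map={params.checkpoint_scope: params.checkpoint_scope})

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

params = tf.contrib.training.HParams(
    model_dir='/tmp/finetune_demo',     # hypothetical output directory
    checkpoint_path='/tmp/pretrained',  # hypothetical pretrained checkpoint (file or directory)
    checkpoint_scope='OptimizeLoss/')

estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir=params.model_dir, params=params)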
The second is to use hooks. A hook that restores the pretrained variables can be passed through the train_monitors parameter when the tf.contrib.learn.Experiment is defined (a sketch of such a hook follows the snippet below):
# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,                           # Estimator
    train_input_fn=train_input_fn,                 # First-class function
    eval_input_fn=eval_input_fn,                   # First-class function
    train_steps=params.train_steps,                # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],                           # Hooks for training
    # eval_hooks=[eval_input_hook],                # Hooks for evaluation
    eval_steps=params.eval_steps                   # Eval steps (assumed; value truncated in the original snippet)
)
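The hook itself is not shown above. Below is a minimal sketch of what such a training hook could look like, assuming a tf.train.SessionRunHook that re-restores the pretrained variables with a tf.train.Saver once the session has been created; the class name RestoreHook and its constructor arguments are illustrative, not from the original.

import tensorflow as tf

class RestoreHook(tf.train.SessionRunHook):
    """Restores pretrained variables right after the session is created."""

    def __init__(self, checkpoint_path, scope=None):
        self._checkpoint_path = checkpoint_path
        self._scope = scope
        self._saver = None

    def begin(self):
        # Called while the graph is still mutable: build a Saver over the
        # variables to warm-start (all global variables if scope is None).
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._scope)
        self._saver = tf.train.Saver(var_list=var_list)

    def after_create_session(self, session, coord):
        # Runs once per session, after the Estimator's own init/restore.
        checkpoint_path = self._checkpoint_path
        if tf.gfile.IsDirectory(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        self._saver.restore(session, checkpoint_path)

It would then be passed as train_monitors=[RestoreHook(params.checkpoint_path, scope='OptimizeLoss/')] in the Experiment above. Because after_create_session also runs when training resumes from model_dir, a hook like this should be guarded (for example with tf.train.latest_checkpoint(params.model_dir)) so it does not overwrite weights that have already been finetuned.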