- 在验证指标（NDCG）达到目前最佳时保存模型参数
# Checkpoint the model only when the current NDCG beats the best seen so far.
# Only trainable variables are saved. Non-trainable variables (e.g. optimizer
# slot variables) are not, so this checkpoint is meant for evaluation or
# initialization rather than for resuming optimizer state.
# NOTE(review): if this snippet runs once per epoch, constructing a new Saver
# each time keeps adding save/restore ops to the graph. Consider building
# `saver_ckpt` once, before the training loop. TODO confirm placement.
saver_ckpt = tf.train.Saver(tf.trainable_variables())
if ndcg > max_ndcg:
    max_ndcg = ndcg
    max_res = cur_res
    # Saves are made only on improvement, so the most recent checkpoint written
    # here (recorded in the directory's "checkpoint" index file) is always the
    # best one. The restore step relies on this.
    saver_ckpt.save(sess, ckpt_save_path + ckpt_save_file, global_step=epoch_count)
    print("saved best", epoch_count)
- 加载上面保存的最佳模型参数
# Checkpoint directory, e.g. "Pretrain/dropout/<dataset>_embed_32_32_32_32_32_32/".
# Presumably model.nc is a list of layer sizes such as [32, 32, 32, 32, 32, 32];
# the join turns it into "32_32_32_32_32_32". TODO confirm against the model.
ckpt_save_path = "Pretrain/dropout/%s_embed_%s/" % (args.dataset, "_".join(map(str, model.nc)))
# Must cover the same variable set that was used when saving (trainable variables only).
saver_ckpt = tf.train.Saver(tf.trainable_variables())
# Read the directory's "checkpoint" index file to find the most recently
# written checkpoint. It returns None when no index file exists, so fail with
# a clear message instead of an AttributeError on None.
ckpt = tf.train.get_checkpoint_state(ckpt_save_path)
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No checkpoint found in %s" % ckpt_save_path)
saver_ckpt.restore(sess, ckpt.model_checkpoint_path)