1、首先确定模型参数要保存的位置
这个位置可以根据自己的需求来进行设定,下面只是举了一个例子。
checkpoint_dir = './training_checkpoints'
2、文件名
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
3、训练完成之后直接对模型进行保存
model.save_weights(checkpoint_prefix.format(epoch=epoch))
上面的epoch是要和训练的伦次相结合来进行使用的,就是可以设置你自己想要的伦次来进行模型的保存。
4、恢复模型
# 恢复模型结构
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
# 从检测点中获得训练后的模型参数
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
在恢复模型结构的过程中,不能改变模型训练时原有的结构,但是能改变传进去数据的结构,恢复模型的第二步就是把最后一个训练后的模型数据重新传到你构建的模型结构中。