基于tensorflow框架的模型参数保存、重载及输出
环境介绍
编程语言:Python3.5
框架采用:Tensorflow-gpu = 1.1.0
需求驱动
在训练模型过程中不免会遇到需要存储模型参数的情况,在tensorflow框架下提供和Saver.save()函数来保存参数,保存的对象包括:权重及在程序中定义的变量,不包含图结构,保存的文件为checkpoint 文件。
代码示例
变量的定义(要记得导入需要的库,如tensorflow, numpy等)
import tensorflow as tf
import os
...
w1 = tf.Variable(tf.random_normal(shape))
w2 = tf.Variable(tf.random_normal(shape))
保存变量, 重载(如果存在已经保存的chenkpoint文件)
#-------定义文件存储路径-------#
ckpt_path = './ckpt_path'
if not os.path.exists(ckpt_path