#!/usr/bin/python3
# -*- coding:utf-8 -*-
import argparse
import os

import numpy as np
import tensorflow as tf
# Demo of TF1 checkpointing: a 2x2 variable `x` is restored from the latest
# checkpoint in `logdir` (or initialized on the first run), printed,
# incremented by 1, and saved again, so each run prints the previous value + 1.
tf.reset_default_graph()

# Command-line options: the checkpoint file-name prefix (written inside
# logdir) and the step number that Saver.save appends to that prefix.
parser = argparse.ArgumentParser()
parser.add_argument("-md", "--model_name", help="The model name",
                    type=str, default="model.ckpt")
parser.add_argument("-gs", "--global_step",
                    help="Step number appended to the checkpoint name",
                    type=int, default=10)
args = parser.parse_args()
print("args:", args)

logdir = './output/'
x = tf.Variable([[3, 4], [4, 5]], dtype=tf.float32, name='x')

sess = tf.InteractiveSession(graph=tf.get_default_graph())
try:
    saver = tf.train.Saver()

    # Check whether a checkpoint was saved by a previous run: if so, resume
    # from it; otherwise start from x's initial value.
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        tf.global_variables_initializer().run()
    print(x.eval())

    # Add 1 to every element so the next run restores a different value.
    assign_op = tf.assign(x, x + 1)
    sess.run(assign_op)

    # Saver.save may refuse to write into a missing parent directory, so make
    # sure logdir exists before the first save.
    os.makedirs(logdir, exist_ok=True)
    saver.save(sess, os.path.join(logdir, args.model_name),
               global_step=args.global_step)
    print('-------------')
finally:
    # Release the session even if restore/save raises.
    sess.close()
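# Results (结果):
# First run (第一次运行): no checkpoint exists in ./output/ yet, so x is
# initialized and the script prints
#   [[3. 4.]
#    [4. 5.]]
# Second run (第二次运行): x is restored from the checkpoint written by the
# first run (which saved x after adding 1), so the script prints
#   [[4. 5.]
#    [5. 6.]]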