import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32,shape=[None,1])
y = 4*x+4
w = tf.Variable(tf.random_normal([1],-1,1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x+b
loss = tf.reduce_mean(tf.square(y-y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# 决定训练还是测试
isTrain = False
train_steps = 200
checkpoint_steps = 50
checkpoint_dir = './save_test/'
saver = tf.train.Saver()#默认保存所有的变量
x_data = np.reshape(np.random.rand(10).astype(np.float32),(10,1))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 当在训练100次时发生中断,找到已经存储的模型继续训练
# saver.restore(sess, checkpoint_dir+'model.ckpt-100')
if isTrain:
for i in range(train_steps):
sess.run(train,feed_dict={x:x_data})
if (i+1)%checkpoint_steps==0:
saver.save(sess,checkpoint_dir+'model.ckpt',global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))