代码:
import tensorflow as tf
import matplotlib.pyplot as pyplot
# Train a small tanh MLP to map x -> y on (x, y) pairs streamed from the
# TFRecord file e:\simdata.tf (TensorFlow 1.x queue-runner pipeline), plot
# the check-set loss, and checkpoint the model under e:/save.
#
# NOTE(review): the pasted original had lost all indentation, so the loop
# nesting here is a reconstruction. Checkpoints are written once per
# evaluation round because Saver(max_to_keep=2) and the "-54100" suffix of
# the checkpoint the original restored both point to periodic saving.

CHECKPOINT_DIR = "e:/save"
CHECKPOINT_PREFIX = CHECKPOINT_DIR + "/save1.data"
BATCH_SIZE = 100        # records per training step
STEPS_PER_ROUND = 100   # training steps between evaluations/checkpoints
NUM_ROUNDS = 1000
NUM_CHECK = 628         # len(range(-314, 314)), as in the original

fileQueue = tf.train.string_input_producer(["e:\\simdata.tf"])
tfReader = tf.TFRecordReader()
_, item = tfReader.read(fileQueue)
# Each record holds two scalar float features, "x" and "y".
feature = tf.parse_single_example(
    item,
    features={
        "x": tf.FixedLenFeature([], tf.float32),
        "y": tf.FixedLenFeature([], tf.float32),
    },
)

# Both placeholders take column vectors of shape (batch, 1). The target used
# to share the op name "input", which TF silently renamed to "input_1".
a = tf.placeholder(shape=[None, 1], name="input", dtype=tf.float32)
b = tf.placeholder(shape=[None, 1], name="target", dtype=tf.float32)

# Layer creation order is unchanged so variable names (dense, dense_1, ...)
# still match checkpoints written by the original script.
kernel_init = tf.random_normal_initializer(stddev=0.1)
Layer = tf.layers.dense(a, 32, tf.nn.tanh, kernel_initializer=kernel_init)
Layer1 = tf.layers.dense(Layer, 32, tf.nn.tanh, kernel_initializer=kernel_init)
Layer2 = tf.layers.dense(Layer1, 32, tf.nn.tanh, kernel_initializer=kernel_init)
# NOTE(review): tanh on the output bounds predictions to (-1, 1); this only
# fits the data if every y lies inside that range — confirm.
Layer3 = tf.layers.dense(Layer2, 1, tf.nn.tanh, kernel_initializer=kernel_init)
loss = tf.reduce_mean(tf.square(Layer3 - b))  # mean squared error

# The step counter is advanced by minimize() and drives the staircase decay
# (x0.9 every 1000 steps); it must not be a trainable variable.
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(0.01, global_step, 1000, 0.9, staircase=True)
train = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

sw = tf.summary.FileWriter("e:/log")
tf.summary.scalar("loss", loss)
summall = tf.summary.merge_all()
saver = tf.train.Saver(max_to_keep=2)


def _read_examples(sess, count):
    """Dequeue `count` records; return ([[x], ...], [[y], ...]) column lists."""
    xs, ys = [], []
    for _ in range(count):
        example = sess.run(feature)
        xs.append([example["x"]])
        ys.append([example["y"]])
    return xs, ys


try:
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Resume from the newest checkpoint if there is one; the original
        # hard-coded "e:/save/save1.data-54100" and crashed when it was absent.
        latest = tf.train.latest_checkpoint(CHECKPOINT_DIR)
        if latest is not None:
            saver.restore(sess, latest)
        step = int(sess.run(global_step))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # NOTE(review): the first NUM_CHECK records become the fixed check
            # set; presumably that is one full pass over the file — verify.
            checkX, checkY = _read_examples(sess, NUM_CHECK)
            runnumber = 0   # training steps taken in this run (plot x-axis)
            runary = []
            lossAry = []
            for _round in range(NUM_ROUNDS):
                for _batch in range(STEPS_PER_ROUND):
                    tempA, tempB = _read_examples(sess, BATCH_SIZE)
                    _, summallTemp = sess.run(
                        [train, summall], feed_dict={a: tempA, b: tempB}
                    )
                    step += 1
                    # Log at the global step; the original used the inner loop
                    # index, so every round overwrote steps 0..99.
                    sw.add_summary(summallTemp, step)
                runnumber += STEPS_PER_ROUND
                runary.append(runnumber)
                lossAry.append(sess.run(loss, feed_dict={a: checkX, b: checkY}))
                # Numbered by global_step so a resumed run continues the
                # checkpoint sequence instead of restarting from 100.
                saver.save(sess, CHECKPOINT_PREFIX, global_step=global_step)
        finally:
            # Stop and join the reader threads before the session closes.
            coord.request_stop()
            coord.join(threads)

        pyplot.cla()
        pyplot.plot(runary, lossAry)
        pyplot.show()
finally:
    sw.close()
损失函数下降曲线（横轴为本次运行的累计训练步数，纵轴为检验集上的均方误差）:

效果:
