# -*- coding: utf-8 -*-
from __future__ import division  # keep true division for the precision ratio on Python 2

import math
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

import new_cifar10

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval', """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test', """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train', """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5, """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000, """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False, """Whether to run eval only once.""")
def eval_once(saver, summary_writer, top_k_op, summary_op):
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)  # restore model variables into the current session
            # Extract global_step from the checkpoint filename, e.g. ".../model.ckpt-10000" -> "10000".
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('No checkpoint file found')
            return
        # Coordinator object used to coordinate the shutdown of the queue-runner threads started below.
        coord = tf.train.Coordinator()
        try:
            threads = []  # queue-runner threads managed by coord
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
            # Number of evaluation batches, rounded up.
            num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
            true_count = 0  # counts correct top-1 predictions
            total_sample_count = num_iter * FLAGS.batch_size
            step = 0
            while step < num_iter and not coord.should_stop():
                predictions = sess.run([top_k_op])
                true_count += np.sum(predictions)
                step += 1
            precision = true_count / total_sample_count
            print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='Precision @ 1', simple_value=precision)
            summary_writer.add_summary(summary, global_step)
        except Exception as e:  # pylint: disable=broad-except
            coord.request_stop(e)
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
def evaluate():
    with tf.Graph().as_default() as g:
        eval_data = FLAGS.eval_data == 'test'
        images, labels = new_cifar10.inputs(eval_data=eval_data)  # read the evaluation data set
        logits = new_cifar10.inference(images)  # build the inference graph and compute model outputs
        top_k_op = tf.nn.in_top_k(logits, labels, 1)  # True where the label is the top-1 prediction
        # Restore the moving-average (shadow) versions of the trained parameters for evaluation
        # (constant name assumed to be MOVING_AVERAGE_DECAY in new_cifar10).
        variable_averages = tf.train.ExponentialMovingAverage(new_cifar10.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()  # map from shadow-variable names to Variables to restore
        saver = tf.train.Saver(variables_to_restore)
        summary_op = tf.summary.merge_all()  # op that serializes all collected summaries
        # Create an event file under FLAGS.eval_dir for the summaries written later.
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
        while True:
            eval_once(saver, summary_writer, top_k_op, summary_op)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
def main(argv=None):  # pylint: disable=unused-argument
    new_cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.DeleteRecursively(FLAGS.eval_dir)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    evaluate()


if __name__ == '__main__':
    tf.app.run()
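
# Example invocation (the script filename below is an assumption; the flags are the
# ones defined at the top of this file):
#   python new_cifar10_eval.py --checkpoint_dir=/tmp/cifar10_train \
#       --eval_dir=/tmp/cifar10_eval --eval_data=test --run_once=True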
# Notes: after the download was interrupted this morning, re-downloading kept
# failing and never completed; once it finally finished, the multi-threaded
# evaluation still broke.
# FIX YOU!!!