tensorflow学习之cifar_10模型评估

# -*- coding: utf-8 -*-

import  new_cifar10
import tensorflow as tf
import  time
import datetime
FLAGS=tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir','/tmp/cifar10_eval',"""Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data','test',"""Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir','/tmp/cifar10_train',"""Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs',60*5,"""How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples',10000,"""Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once',False,"""Where to run eval only once.""")

def eval_once(saver,summary_writer,top_k_op,summary_op):
    with tf.Session() as sess:
        ckpt=tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess,ckpt.model_checkpoint_path)  #恢复模型变量到当前session中
            global_step=ckpt.model_checkpoint_path.split('/')[-1]/split('-')[-1] #提取global_step

        else:
            print('No checkpoint file found')
            return

        coord=tf.train.Coordinator()  #启动很多进程 返回coordinator类对象,可以用来coordinate很多线程的结束

        try:
            threads=[]  #使用coord管理进程-
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(sess,coord=coord,daemon=True,start=True)) #计算测试数据块的个数,并向上取整

            num_iter=int(math.ceil(FLAGS.num_examples/FLAGS.batch_size))
            true_count=0
            total_sample_count=num_iter*FLAGS.batch_size
            step=0
            while step<num_iter and not coord.should_stop():
                predictions=sess.run([top_k_op])
                true_count+=np.sum(predictions)
                step+=1

            precison=true_count/total_sample_count
            print('%s:precision @ 1=%.3f'%(datetime.now),precison)

            summary=tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='Precision @ 1',simple_value=precison)
            summary_writer.add_summary(summary,global_step)
        except Exception as e:
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads,stop_grace_period_secs=10)

def evaluate():
    with tf.Graph().as_default() as g:
        eval_data=FLAGS.eval_data=='test'
        images,labels=new_cifar10.inputs(eval_data=eval_data)  #读入测试数据集
        logits=new_cifar10.inference(images)  #使用moving average操作前模型参数,计算模型输出
        top_k_op=tf.nn.in_top_k(logits,labels,1) #恢复移动平均操作后的模型参数
        variable_averages=tf.train.ExponentialMovingAverage(new_cifar10.MOVING_AVERAGE_DEVAY) #恢复moving average操作后的模型参数
        variables_to_restore=variable_averages.variables_to_restore()  #返回要恢复的names到Variables的映射,也即一个map映射
        saver=tf.train.Saver(variables_to_restore)
        summary_op=tf.summary.merge_all()  ##创建序列化后的summary对象
        summary_writer=tf.summary.FileWriter(FLAGS.eval_dir,g)   #创建一个event file,用于之后写summary对象到logdir目录下的文件中
        while True:
            eval_once(saver,summary_writer,top_k_op,summary_op)
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)

def main(argv=None):
    new_cifar10.maybe_download_and_extract()
    if tf.gfile.Exists(FLAGS.eval_dir):
        tf.gfile.DeleteRecursively(FLAGS.eval_dir)
    tf.gfile.MakeDirs(FLAGS.eval_dir)
    evaluate()

if __name__=='__main__':
    tf.app.run()

上午下载被中断之后,重新下载也有问题,一直下不完整,
下完整了,多线程还出现问题。。。
FIX YOU!!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值