tensorflow学习笔记11——开始run

本文深入解析了使用TensorFlow进行机器学习模型训练的关键步骤,包括会话管理、数据集迭代器使用、摘要合并及结果记录等核心环节。通过具体代码示例,详细介绍了如何运用sess.run()获取变量值或执行操作,数据集迭代器的初始化,以及如何利用summary合并所有汇总并将其写入事件文件,便于后续使用TensorBoard进行可视化分析。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、sess.run()

# Create the session and run the graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer)

用sess.run()有两种情况:
1.想要获取某个变量的时候:
2.执行某种操作的时候,这个操作不是一个变量,没有值。比如上图为了初始化全部变量。

2、讲一下DataSet的iterator
参考:https://blog.youkuaiyun.com/briblue/article/details/80962728

注意这段代码前面是还有这么一句的

iterator = train_dataset.make_initializable_iterator()

2、summary.merge_all()合并默认图形中的所有汇总.
merged_summaries是一个节点,必须先传入session.run()运行才能获得真正的汇总!

summary.FileWriter():将汇总结果写入事件(event file)

# Merge all the summary and write
summary_op = tf.compat.v1.summary.merge_all()
train_filewriter = tf.compat.v1.summary.FileWriter('train/', sess.graph)
saver=tf.compat.v1.train.Saver(max_to_keep=1)   #只保留最后一代模型,如果想保留全部,那就把max_to_keep=0

3、truepredictNum += np.sum(predictValue == testValue)计算预测正确的数量
accuracy1 = truepredictNum / 5000.0 #正确率

while (True):
    try:

        lossValue, lr, _ = sess.run([loss, learning_rate, opt_op])   #这里如果改成lossValue, lr = sess.run([loss, learning_rate])

        if step % 100 == 0:
            print("step %i: Learning_rate: %f Loss: %f" % (step, lr, lossValue))

        if step % 1000 == 0:
            saver.save(sess, 'model/my-model', global_step=step)
            truepredictNum = 0
            sess.run([testiterator.initializer, validiterator.initializer])
            accuracy1 = 0.0
            accuracy2 = 0.0

            while (True):
                try:
                    #在验证数据集上预测
                    predictValue, testValue = sess.run([validresult, validrecord_labels])
                    truepredictNum += np.sum(predictValue == testValue)
                except tf.errors.OutOfRangeError:
                    print("valid correct num: %i" % (truepredictNum))

                    accuracy1 = truepredictNum / 5000.0
                    break

            truepredictNum = 0

            while (True):
                try:
                    #在测试数据集上预测
                    predictValue, testValue = sess.run([testresult, testrecord_labels])
                    truepredictNum += np.sum(predictValue == testValue)
                except tf.errors.OutOfRangeError:
                    print("test correct num: %i" % (truepredictNum))

                    accuracy2 = truepredictNum / 10000.0
                    break

            summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
            train_filewriter.add_summary(summary, step)
        step += 1

    except tf.errors.OutOfRangeError:
        break

4、add_summary()将训练过程数据保存在filewriter指定的文件中
回头用tensorboard画图

valid_accuracy = tf.placeholder(tf.float32)
test_accuracy = tf.placeholder(tf.float32)
summary = sess.run(summary_op, feed_dict={valid_accuracy: accuracy1, test_accuracy: accuracy2})
train_filewriter.add_summary(summary, step)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值