TensorFlow -- 训练MNIST数据集

本文介绍了一个基于TensorFlow的手写数字识别模型实现过程。该模型利用MNIST数据集进行训练,并采用Softmax分类器完成预测任务。通过25轮迭代训练后,在测试集上达到了较好的准确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# -*- coding:utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pylab

# 下载并解压数据
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# 重置图
tf.reset_default_graph()

# 定义占位符
# 数据集的维度是28*28=874
x = tf.placeholder(tf.float32, [None, 784])
# 共10个类别
y = tf.placeholder(tf.float32, [None, 10])

# 定义学习参数
W = tf.Variable(tf.random_normal(([784, 10])))
b = tf.Variable(tf.zeros([10]))

# 正向传播
# softmax分类器
pred = tf.nn.softmax(tf.matmul(x, W) + b)

# 反向结构
# 损失函数
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))

# 定义参数
learning_rate = 0.01
# 使用梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# 迭代次数
training_epochs = 25
batch_size = 100
display_step = 1

saver = tf.train.Saver()
model_path = 'log/521model.ckpt'

# 启动session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 启动循环开始训练
    for epoch in range(training_epochs):
        avg_cost = 0

        # total_batch=550
        total_batch = int(mnist.train.num_examples/batch_size)
        # 循环所有的数据集
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # 运行优化器
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})

            # 计算平均的loss值
            avg_cost += c / total_batch

        # 显示训练中的详细信息
        if (epoch+1) % display_step == 0:
            print('Epochs:', '%04d' % (epoch+1), 'cost:', '{:.9f}'.format(avg_cost))
    print('Finished!')

    # 测试模型
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # 计算精确度
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('Accurary:', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # 保存模型
    save_path = saver.save(sess, model_path)
    print('Model saved in file: %s' % save_path)

# 读取模型
print('Starting loading model')
with tf.Session() as sess:
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    # 恢复模型变量
    saver.restore(sess, model_path)

    # 测试模型
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('Accuracy:', accuracy.eval({x: mnist.test.images, y:mnist.test.labels}))

    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval, predv = sess.run([output, pred], feed_dict={x: batch_xs, y:batch_ys})
    print(outputval, predv, batch_ys)

    im = batch_xs[0]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()

    im = batch_xs[1]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值