3-Tensorflow-demo_10_02-批量和Epoch实现MNIST并可视化

本文介绍了一个使用TensorFlow实现的手写数字识别模型。通过批量训练和随机梯度下降(SGD),模型能够在MNIST数据集上达到90.23%的准确率。文章详细解释了如何划分数据集为批量,构建模型并进行训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import tensorflow as tf
import pprint
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import os
"""
批量和 SGD是绝配。
   1、你需要将数据集划分成 批量数据;比如样本是1000个 你的批量大小128,
     有7个128的batch, 还有一个104的batch。
     解决办法:tf.placeholder(shape=[None, 784])
"""


def batches(batch_size, features, labels):
    """
    实现批量数据的获取
    :param batch_size:
    :param features:
    :param labels:
    :return:
    """
    assert len(features) == len(labels)
    output_batches = []
    for start in range(0, len(features), batch_size):
        end = start + batch_size
        batch = [features[start: end], labels[start: end]]
        output_batches.append(batch)
    return output_batches


def get_batch(batch_size, features, labels):
    """
    构建获取批量数据的生成器 generator. (range())
    :param batch_size:
    :param features:
    :param labels:
    :return:
    """
    assert len(features) == len(labels)
    for start in range(0, len(features), batch_size):
        end = start + batch_size
        batch_x = features[start: end]
        batch_y = labels[start: end]
        yield batch_x, batch_y

# 测试样本
example_features = [
    ['F11','F12','F13','F14'],
    ['F21','F22','F23','F24'],
    ['F31','F32','F33','F34'],
    ['F41','F42','F43','F44']]
example_labels = [
    ['L11','L12'],
    ['L21','L22'],
    ['L31','L32'],
    ['L41','L42']]

my_graph = tf.Graph()
def mnist_batch_epochs():
    """
    批量+ epochs实现手写数据集的分类任务。
    :return:
    """
    with my_graph.as_default():
        # 0、定义模型的特征、类别等变量。
        n_input = 784    # 特征数量
        n_classes = 10  # 类别数量
        epochs = 20     # 迭代的次数 (1个epochs 代表所有样本 正向+反向传播 执行1次)
        batch_size = 128   # 每批训练时 批量的大小。
        lr = 0.01    # 学习率
        with tf.variable_scope('network'):
            # 1、输入的占位符(特征+标签)
            features = tf.placeholder(tf.float32, shape=[None, n_input], name='inputx')
            labels = tf.placeholder(tf.float32, shape=[None, n_classes], name='inputy')
            learning_rate = tf.placeholder(tf.float32, name='lr')

            # 2、模型的变量
            w = tf.get_variable(
                'weights', shape=[n_input, n_classes], dtype=tf.float32,
                initializer=tf.random_normal_initializer(stddev=0.1)
            )
            b = tf.get_variable(
                'bias', shape=[n_classes], dtype=tf.float32,
                initializer=tf.zeros_initializer()
            )
            # 3、正向传播,获取logits和模型的预测概率
            logits = tf.matmul(features, w) + b     # [N, n_classes]
            prediction = tf.nn.softmax(logits)

        with tf.name_scope('loss'):
            # 4、构建损失
            loss = tf.reduce_mean(
                -tf.reduce_sum(labels * tf.log(prediction), axis=1)
            )
            # 可视化损失
            tf.summary.scalar('loss', tensor=loss)

        with tf.name_scope('optimizer'):
            # 5、构建模型优化器
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
            train_opt = optimizer.minimize(loss=loss)

        with tf.name_scope('accuracy'):
            # 6、计算模型准确率
            correct_pred = tf.equal(tf.argmax(logits, axis=1), tf.argmax(labels, axis=1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
            tf.summary.scalar('accuracy', tensor=accuracy)

        # 7、构建持久化对象
        saver = tf.train.Saver(max_to_keep=2)
        checkpoint_path = './models/mnist'
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)
            print('创建文件夹:{}'.format(checkpoint_path))

    # 二、会话
    with tf.Session(graph=my_graph) as sess:
        # a、变量的初始化 或者 加载持久化文件
        sess.run(tf.global_variables_initializer())
        # b、加载真正的数据
        mnist = input_data.read_data_sets('../datas/mnist', one_hot=True)

        # 读取训练、验证、测试数据。(mnist数据已经做过归一化的, 而且随机打乱过)
        train_features = mnist.train.images
        val_features = mnist.validation.images
        test_features = mnist.test.images

        train_labels = mnist.train.labels.astype(np.float32)
        val_labels = mnist.validation.labels.astype(np.float32)
        test_labels = mnist.test.labels.astype(np.float32)

        train_batches = batches(batch_size, train_features, train_labels)
        # 构建可视化代码
        summary = tf.summary.merge_all()
        writer = tf.summary.FileWriter(
            logdir='./models/mnist/graph', graph=sess.graph
        )

        # c、构建循环迭代
        step = 1
        for e in range(1, epochs):
            # d、构建批量数据的迭代循环
            for batch_x, batch_y in train_batches:
                # e、执行模型优化
                train_dict = {features: batch_x, labels: batch_y, learning_rate: lr}
                sess.run(train_opt, train_dict)

                # f、打印训练 和 验证数据的 损失和准确率
                if step % 100 ==0:
                    train_loss, train_acc, summary_ = sess.run([loss, accuracy, summary], train_dict)
                    writer.add_summary(summary_, global_step=step)

                    valid_dict = {features: val_features[:512], labels: val_labels[:512]}
                    val_loss, val_acc = sess.run([loss, accuracy], valid_dict)
                    print('Epochs:{} - Step:{} - Train Loss:{:.5f} - Valid Loss:{:.5f} - Valid Acc:'
                          '{:.5f}'.format(e, step, train_loss, val_loss, val_acc))
                step += 1

            # g、执行模型持久化
            if e % 3 == 0:
                save_file_name = 'model.ckpt'
                save_file = os.path.join(checkpoint_path, save_file_name)
                saver.save(sess, save_path=save_file, global_step=e)
                print('模型成功保存至:{}'.format(save_file))

        # h、执行测试数据集,获得测试数据集的准确率。
        test_dict = {features: test_features[:512], labels: test_labels[:512]}
        test_acc = sess.run(accuracy, test_dict)
        print(' Test Acc:{:.5f}'.format(test_acc))
        writer.close()


if __name__ == '__main__':
    # batches_data = batches(
    #     batch_size=3, features=example_features, labels=example_labels
    # )
    # pprint.pprint(batches_data)
    # print('**'*40)
    # for batch_x, batch_y in get_batch(3, example_features, example_labels):
    #     print(batch_x, batch_y)
    mnist_batch_epochs()
创建文件夹:./models/mnist
Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz
Epochs:1 - Step:100 - Train Loss:1.77701 - Valid Loss:1.56862 - Valid Acc:0.51172
Epochs:1 - Step:200 - Train Loss:1.32301 - Valid Loss:1.21084 - Valid Acc:0.68359
Epochs:1 - Step:300 - Train Loss:0.89667 - Valid Loss:1.00256 - Valid Acc:0.75391
Epochs:1 - Step:400 - Train Loss:0.86192 - Valid Loss:0.88060 - Valid Acc:0.79102
Epochs:2 - Step:500 - Train Loss:0.90040 - Valid Loss:0.78734 - Valid Acc:0.81055
Epochs:2 - Step:600 - Train Loss:0.89693 - Valid Loss:0.73026 - Valid Acc:0.83008
Epochs:2 - Step:700 - Train Loss:0.75104 - Valid Loss:0.68255 - Valid Acc:0.83203
Epochs:2 - Step:800 - Train Loss:0.73570 - Valid Loss:0.64726 - Valid Acc:0.84180
Epochs:3 - Step:900 - Train Loss:0.59377 - Valid Loss:0.61561 - Valid Acc:0.84570
Epochs:3 - Step:1000 - Train Loss:0.55976 - Valid Loss:0.58995 - Valid Acc:0.85938
Epochs:3 - Step:1100 - Train Loss:0.64631 - Valid Loss:0.57156 - Valid Acc:0.86328
Epochs:3 - Step:1200 - Train Loss:0.42934 - Valid Loss:0.55611 - Valid Acc:0.86914
模型成功保存至:./models/mnist/model.ckpt
Epochs:4 - Step:1300 - Train Loss:0.51495 - Valid Loss:0.53800 - Valid Acc:0.86914
Epochs:4 - Step:1400 - Train Loss:0.55692 - Valid Loss:0.52522 - Valid Acc:0.87305
Epochs:4 - Step:1500 - Train Loss:0.56353 - Valid Loss:0.51471 - Valid Acc:0.87891
Epochs:4 - Step:1600 - Train Loss:0.39421 - Valid Loss:0.50483 - Valid Acc:0.87891
Epochs:4 - Step:1700 - Train Loss:0.50704 - Valid Loss:0.49617 - Valid Acc:0.87891
Epochs:5 - Step:1800 - Train Loss:0.58189 - Valid Loss:0.48657 - Valid Acc:0.88086
Epochs:5 - Step:1900 - Train Loss:0.34614 - Valid Loss:0.47809 - Valid Acc:0.88477
Epochs:5 - Step:2000 - Train Loss:0.39491 - Valid Loss:0.47156 - Valid Acc:0.88477
Epochs:5 - Step:2100 - Train Loss:0.59586 - Valid Loss:0.46611 - Valid Acc:0.88477
Epochs:6 - Step:2200 - Train Loss:0.38852 - Valid Loss:0.45958 - Valid Acc:0.88477
Epochs:6 - Step:2300 - Train Loss:0.39894 - Valid Loss:0.45207 - Valid Acc:0.88867
Epochs:6 - Step:2400 - Train Loss:0.38053 - Valid Loss:0.44869 - Valid Acc:0.88672
Epochs:6 - Step:2500 - Train Loss:0.43628 - Valid Loss:0.44457 - Valid Acc:0.88867
模型成功保存至:./models/mnist/model.ckpt
Epochs:7 - Step:2600 - Train Loss:0.53886 - Valid Loss:0.43915 - Valid Acc:0.89062
Epochs:7 - Step:2700 - Train Loss:0.40859 - Valid Loss:0.43430 - Valid Acc:0.89062
Epochs:7 - Step:2800 - Train Loss:0.49389 - Valid Loss:0.43099 - Valid Acc:0.89648
Epochs:7 - Step:2900 - Train Loss:0.53990 - Valid Loss:0.42923 - Valid Acc:0.89062
Epochs:7 - Step:3000 - Train Loss:0.33528 - Valid Loss:0.42447 - Valid Acc:0.89648
Epochs:8 - Step:3100 - Train Loss:0.35696 - Valid Loss:0.41999 - Valid Acc:0.89062
Epochs:8 - Step:3200 - Train Loss:0.46494 - Valid Loss:0.41574 - Valid Acc:0.89648
Epochs:8 - Step:3300 - Train Loss:0.48329 - Valid Loss:0.41509 - Valid Acc:0.89648
Epochs:8 - Step:3400 - Train Loss:0.59051 - Valid Loss:0.41300 - Valid Acc:0.89648
Epochs:9 - Step:3500 - Train Loss:0.69715 - Valid Loss:0.40948 - Valid Acc:0.89844
Epochs:9 - Step:3600 - Train Loss:0.23007 - Valid Loss:0.40518 - Valid Acc:0.89648
Epochs:9 - Step:3700 - Train Loss:0.42141 - Valid Loss:0.40386 - Valid Acc:0.89844
Epochs:9 - Step:3800 - Train Loss:0.27866 - Valid Loss:0.40251 - Valid Acc:0.89648
模型成功保存至:./models/mnist/model.ckpt
Epochs:10 - Step:3900 - Train Loss:0.62987 - Valid Loss:0.39955 - Valid Acc:0.89844
Epochs:10 - Step:4000 - Train Loss:0.44075 - Valid Loss:0.39590 - Valid Acc:0.90234
Epochs:10 - Step:4100 - Train Loss:0.41441 - Valid Loss:0.39433 - Valid Acc:0.89844
Epochs:10 - Step:4200 - Train Loss:0.40160 - Valid Loss:0.39525 - Valid Acc:0.89844
Epochs:10 - Step:4300 - Train Loss:0.46934 - Valid Loss:0.39123 - Valid Acc:0.90234
Epochs:11 - Step:4400 - Train Loss:0.52377 - Valid Loss:0.38873 - Valid Acc:0.90039
Epochs:11 - Step:4500 - Train Loss:0.43868 - Valid Loss:0.38612 - Valid Acc:0.90234
Epochs:11 - Step:4600 - Train Loss:0.22731 - Valid Loss:0.38646 - Valid Acc:0.89844
Epochs:11 - Step:4700 - Train Loss:0.30399 - Valid Loss:0.38547 - Valid Acc:0.90430
Epochs:12 - Step:4800 - Train Loss:0.37831 - Valid Loss:0.38252 - Valid Acc:0.90039
Epochs:12 - Step:4900 - Train Loss:0.59598 - Valid Loss:0.37937 - Valid Acc:0.90039
Epochs:12 - Step:5000 - Train Loss:0.41943 - Valid Loss:0.37839 - Valid Acc:0.90234
Epochs:12 - Step:5100 - Train Loss:0.45527 - Valid Loss:0.37852 - Valid Acc:0.90234
模型成功保存至:./models/mnist/model.ckpt
Epochs:13 - Step:5200 - Train Loss:0.34202 - Valid Loss:0.37693 - Valid Acc:0.90430
Epochs:13 - Step:5300 - Train Loss:0.31640 - Valid Loss:0.37332 - Valid Acc:0.90430
Epochs:13 - Step:5400 - Train Loss:0.40235 - Valid Loss:0.37344 - Valid Acc:0.90234
Epochs:13 - Step:5500 - Train Loss:0.23648 - Valid Loss:0.37402 - Valid Acc:0.90039
Epochs:14 - Step:5600 - Train Loss:0.30446 - Valid Loss:0.37053 - Valid Acc:0.90430
Epochs:14 - Step:5700 - Train Loss:0.33212 - Valid Loss:0.36973 - Valid Acc:0.90430
Epochs:14 - Step:5800 - Train Loss:0.39055 - Valid Loss:0.36759 - Valid Acc:0.90039
Epochs:14 - Step:5900 - Train Loss:0.25466 - Valid Loss:0.36874 - Valid Acc:0.90039
Epochs:14 - Step:6000 - Train Loss:0.31771 - Valid Loss:0.36710 - Valid Acc:0.90430
Epochs:15 - Step:6100 - Train Loss:0.44622 - Valid Loss:0.36535 - Valid Acc:0.90625
Epochs:15 - Step:6200 - Train Loss:0.24977 - Valid Loss:0.36221 - Valid Acc:0.90430
Epochs:15 - Step:6300 - Train Loss:0.29116 - Valid Loss:0.36254 - Valid Acc:0.90234
Epochs:15 - Step:6400 - Train Loss:0.46274 - Valid Loss:0.36255 - Valid Acc:0.90625
模型成功保存至:./models/mnist/model.ckpt
Epochs:16 - Step:6500 - Train Loss:0.27390 - Valid Loss:0.36135 - Valid Acc:0.90625
Epochs:16 - Step:6600 - Train Loss:0.29047 - Valid Loss:0.35821 - Valid Acc:0.90430
Epochs:16 - Step:6700 - Train Loss:0.29196 - Valid Loss:0.35851 - Valid Acc:0.90430
Epochs:16 - Step:6800 - Train Loss:0.32041 - Valid Loss:0.35856 - Valid Acc:0.90430
Epochs:17 - Step:6900 - Train Loss:0.40900 - Valid Loss:0.35714 - Valid Acc:0.90625
Epochs:17 - Step:7000 - Train Loss:0.29734 - Valid Loss:0.35540 - Valid Acc:0.90430
Epochs:17 - Step:7100 - Train Loss:0.37197 - Valid Loss:0.35458 - Valid Acc:0.90625
Epochs:17 - Step:7200 - Train Loss:0.46864 - Valid Loss:0.35597 - Valid Acc:0.90430
Epochs:17 - Step:7300 - Train Loss:0.25809 - Valid Loss:0.35376 - Valid Acc:0.90625
Epochs:18 - Step:7400 - Train Loss:0.30165 - Valid Loss:0.35189 - Valid Acc:0.90430
Epochs:18 - Step:7500 - Train Loss:0.39663 - Valid Loss:0.34958 - Valid Acc:0.90820
Epochs:18 - Step:7600 - Train Loss:0.41498 - Valid Loss:0.35148 - Valid Acc:0.90430
Epochs:18 - Step:7700 - Train Loss:0.52507 - Valid Loss:0.35132 - Valid Acc:0.90430
模型成功保存至:./models/mnist/model.ckpt
Epochs:19 - Step:7800 - Train Loss:0.65190 - Valid Loss:0.34996 - Valid Acc:0.90430
Epochs:19 - Step:7900 - Train Loss:0.17400 - Valid Loss:0.34745 - Valid Acc:0.90039
Epochs:19 - Step:8000 - Train Loss:0.34953 - Valid Loss:0.34751 - Valid Acc:0.90039
Epochs:19 - Step:8100 - Train Loss:0.22821 - Valid Loss:0.34797 - Valid Acc:0.90234
 Test Acc:0.90234

Process finished with exit code 0

可视化代码:

tensorboard --logdir/home/hjz/PycharmProjects/lianxi/11_tensorflow1.4.0/models/mnist/graph

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值