25-mnist08_3layers

本文介绍了一个基于TensorFlow的手写数字识别系统,使用MNIST数据集进行训练和验证。模型采用全连接神经网络,包含两个隐藏层,通过Adam优化器进行参数更新,实现了对手写数字的有效识别。
部署运行你感兴趣的模型镜像
from tensorflow.examples.tutorials.mnist.input_data import read_data_sets
import numpy as np
import cv2
import tensorflow as tf


class Config:
    def __init__(self):
        self.sample_path = '../deeplearning_ai12/p07_mnist/MNIST_data'
        self.lr = 0.001
        self.epoches = 200
        self.batch_size = 50
        self.eps = 1e-10
        self.base_filters = 16  # should be 32 at least

        self.name = 'mnist08'
        self.save_path = '../models/{name}/{name}'.format(name=self.name)
        self.logdir = '../logs/{name}'.format(name=self.name)


class Tensors:
    def __init__(self, config: Config):
        self.config = config
        self.x = tf.placeholder(tf.float32, [None, 784], 'x')
        logits = self.get_logits(self.x)  # [-1, 10]
        self.y_predict = tf.argmax(logits, axis=1, output_type=tf.int32)  # [-1]

        self.y = tf.placeholder(tf.int32, [None], 'y')
        y = tf.one_hot(self.y, 10)  # [-1, 10]

        self.loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
        self.loss = tf.reduce_mean(self.loss)
        opt = tf.train.AdamOptimizer(config.lr)
        self.train_op = opt.minimize(self.loss)

        self.precise = tf.reduce_mean(tf.cast(tf.equal(self.y, self.y_predict), tf.float32))

        params = 0
        for var in tf.trainable_variables():
            ps = _params(var.shape)
            print(var.name, var.shape, ps)
            params += ps
        print('-' * 200)
        print('Total:', params)

        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('precise', self.precise)
        self.summary_op = tf.summary.merge_all()

    def get_logits(self, x):
        """

        :param x: [-1, 784]
        :return: [-1, 10]
        """
        x = tf.layers.dense(x, 4000, tf.nn.relu, name='dense1')
        x = tf.layers.dense(x, 10, name='dense2')
        return x


def _params(shape):
    result = 1
    for sh in shape:
        result *= sh.value
    return result


class Samples:
    def __init__(self, config):
        ds = read_data_sets(config.sample_path)

        self.train = SubSamples(ds.train)
        self.validation = SubSamples(ds.validation)
        self.test = SubSamples(ds.test)


class SubSamples:
    def __init__(self, data):
        self.data = data

    def num_examples(self):
        return self.data.num_examples

    def next_batch(self, batch_size):
        return self.data.next_batch(batch_size)  # xs: [batch_size, 784], ys: [batch_size]


def show_imgs(xs, ys):
    print(ys)
    xs = np.reshape(xs, [-1, 28, 28])
    xs = np.transpose(xs, [1, 0, 2])  # [28, -1, 28]
    xs = np.reshape(xs, [28, -1, 28 * 20])  # [28, -1, 560],
    xs = np.transpose(xs, [1, 0, 2])  # [-1, 28, 560]
    xs = np.reshape(xs, [-1, 28 * 20])

    cv2.imshow('My digits', xs)
    cv2.waitKey()


class App:
    def __init__(self, config: Config):
        self.config = config
        self.samples = Samples(config)

        g = tf.Graph()
        with g.as_default():
            self.tensors = Tensors(config)
            self.session = tf.Session(graph=g)
            self.saver = tf.train.Saver()

            try:
                self.saver.restore(self.session, config.save_path)
                print('Restore the model from %s successfully' % config.save_path)
            except:
                print('Fail to restore the model from %s, use a new model instead' % config.save_path)
                self.session.run(tf.global_variables_initializer())

    def close(self):
        self.session.close()

    def train(self):
        train_samples = self.samples.train
        config = self.config
        ts = self.tensors

        fw = tf.summary.FileWriter(config.logdir, self.session.graph)
        step = 0
        for epoch in range(config.epoches):
            batches = train_samples.num_examples() // config.batch_size
            for batch in range(batches):
                xs, ys = train_samples.next_batch(config.batch_size)
                _, summary = self.session.run([ts.train_op, ts.summary_op], {ts.x: xs, ts.y: ys})
                fw.add_summary(summary, step)
                step += 1

                xs, ys = self.samples.validation.next_batch(config.batch_size)
                precise_v = self.session.run(ts.precise, {ts.x: xs, ts.y: ys})

                print('Epoch: %d, batch %d: precise=%.6f' % (epoch, batch, precise_v))
            self.saver.save(self.session, config.save_path)
            print('Model saved into', config.save_path)
        print('Training is finished!')

    def predict(self):
        xs, ys = self.samples.test.next_batch(self.config.batch_size)
        print(ys)

        ts = self.tensors
        ys_predict = self.session.run(ts.y_predict, {ts.x: xs})
        print('predict:')
        print(ys_predict)


if __name__ == '__main__':
    config = Config()
    app = App(config)

    app.train()
    # app.predict()
    app.close()

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值