CNN学习(三)—Tensorflow 进行MNIST手写体识别

本文构建了一个六层的卷积神经网络(CNN)并应用于MNIST手写数字数据集的训练中,通过详细的代码实现展示了从搭建网络结构、定义损失函数到训练过程的全部细节。

前言

本节,我们牛刀小试一下,使用Tensor的构建一个简单的六层CNN网络来对MNIST手写体数据集进行训练。

网络结构:

这里写图片描述

代码

__author__ = 'jmh081701'
#coding:utf-8
import  tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow.examples.tutorials.mnist import input_data
#载入MNIST数据
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
#mnist.train.images是一个列表:shape:(55000,784),在输入前需要先转换为con2v的input参数个数的形式
#mnist.train.labels标签,也是一个列表,shape:(55000,10)
#mnist.test.images是测试集的

imageX=tf.placeholder(dtype=tf.float32,shape=[None,784])
#imageX是训练是的输入图像
labelY=tf.placeholder(dtype=tf.float32,shape=[None,10])
#keep_prob=tf.placeholder(tf.float32)
#lableY是训练时图像对应的标签。shape第一个参数为-1意为具体样本数待定
with tf.name_scope('C1'):
    W_C1=tf.Variable(tf.truncated_normal([5,5,1,32],stddev=0.1),dtype=tf.float32)
    b_C1=tf.Variable(tf.constant(0.1,tf.float32,shape=[32]))
    #W_C1是C1层的权值矩阵,它也是卷积核,共有10个卷积核。
    # b_C1则是偏置
    X=tf.reshape(imageX,[-1,28,28,1])
    #需要对输入转化为conv2d想要的格式
    featureMap_C1=tf.nn.conv2d(X,W_C1,[1,1,1,1],padding='SAME')+b_C1
    #conv2d的参数:
    #input:[图片个数,图片长,图片宽,图片的通道数]
    #filter:[滤波器长,滤波器宽,输入通道数,输出通道数]
    #stride:[1,1,1,1] 在四个轴上跳跃的大小
    #OK,C1卷积完成

with tf.name_scope('f'):
    relu_C1=tf.nn.relu(featureMap_C1)  #激活层
with tf.name_scope('S2'):
    featureMap_S2=tf.nn.max_pool(relu_C1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    #S2的池化。
with tf.name_scope('C3'):
    W_C3=tf.Variable(tf.truncated_normal([5,5,32,64],stddev=0.1))
    b_C3=tf.Variable(tf.constant(0.1,tf.float32,shape=[64]))
    featureMap_C3=tf.nn.conv2d(featureMap_S2,W_C3,[1,1,1,1],padding='SAME')+b_C3

with tf.name_scope('f'):
    relu_C3=tf.nn.relu(featureMap_C3)
with tf.name_scope('S4'):
    featureMap_S4=tf.nn.max_pool(relu_C3,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
#C3以及S4的过程
with tf.name_scope('flat'):
    fetureMap_flatter=tf.reshape(featureMap_S4,[-1,7*7*64])
#栅格化
with tf.name_scope('fullcont'):
    W_F5=tf.Variable(tf.truncated_normal([7*7*64,1024],stddev=0.1))
    b_F5=tf.Variable(tf.constant(0.1,tf.float32,shape=[1024]))
    out_F5=tf.nn.relu(tf.matmul(fetureMap_flatter,W_F5)+b_F5)
    #out_F5_drop=tf.nn.dropout(out_F5,keep_prob)
#全连接层完成
with tf.name_scope('output'):
    W_OUTPUT=tf.Variable(tf.truncated_normal([1024,10],stddev=0.1))
    b_OUTPUT=tf.Variable(tf.constant(0.1,tf.float32,shape=[10]))
    predictY=tf.nn.softmax(tf.matmul(out_F5,W_OUTPUT)+b_OUTPUT)

#输出层,使用softmax函数

loss=tf.reduce_mean(-tf.reduce_sum(labelY*tf.log(predictY)))
tf.summary.histogram('loss',loss)
tf.summary.scalar('loss',loss)
#残差函数loss设置为交叉熵
learning_rate=1e-4
#train_op=tf.train.AdamOptimizer(learning_rate).minimize(loss)
train_op=tf.train.AdamOptimizer(learning_rate).minimize(loss)

#设置训练方法,采用Ada最优化方法
y_pred=tf.arg_max(predictY,1)
bool_pred=tf.equal(tf.arg_max(labelY,1),y_pred)

right_rate=tf.reduce_mean(tf.to_float(bool_pred))

#检查错误率
saver=tf.train.Saver()
def load_model(sess,dir,modelname):
    ckpt=tf.train.get_checkpoint_state(dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("*"*30)
        print("load lastest model......")
        saver.restore(sess,dir+".\\"+modelname)
        print("*"*30)

def save_model(sess,dir,modelname):
    saver.save(sess,dir+modelname)
dir=r"C:\\Users\\jmh081701\\Documents\\TempWorkStation\\python\\tensorflow\\cnnmodel\\"
modelname="cnnmodel"
with tf.Session() as sess:
    init =tf.global_variables_initializer()
    sess.run(init)
    step=1
    sameMAX=10
    sameStep=0
    accSum=0
    batch_epoch=int(mnist.train.num_examples/100)
    load_model(sess,dir,modelname)

    writer=tf.summary.FileWriter(".//cnngrahph",tf.get_default_graph())
    merged=tf.summary.merge_all()
    while True:
        if(step%batch_epoch==0):
            #测试一下
            test_img,test_lab=mnist.test.next_batch(1000)
            #test_img=mnist.test.images
            #test_lab=mnist.test.labels
            acc=sess.run(right_rate,{imageX:test_img,labelY:test_lab})
            tf.summary.histogram('rightrate',acc)
            print({"!!!!!!!!!!!!!!testing:"+str(step):acc})
            accSum=accSum+acc
            if(sameStep%sameMAX==0):
                if(acc==accSum/sameMAX):
                    print({step:acc})
                    break;
                else:
                    accSum=0
                    sameStep=0
            step=step+1
            save_model(sess,dir,modelname)
            continue
        image_batch,label_batch=mnist.train.next_batch(100)
        l,op,summary=sess.run([loss,train_op,merged],feed_dict={imageX:image_batch,labelY:label_batch})
        print(step,l)
        if(step%20==0):
            #每隔20批,跟踪一次
            writer.add_summary(summary,step)
        step=step+1
    save_model(sess,dir,modelname)
    print(sess.run([right_rate],feed_dict={imageX:mnist.test.images,labelY:mnist.test.labels}))
#注意点!在随机化参数的时候,不能标准差不能太大,否则很容易在计算log时出现nan

#注意点!在随机化参数的时候,不能标准差不能太大,否则很容易在计算log时出现nan

运行结果

···
547 17.2985
548 10.7501
549 22.6312
{'!!!!!!!!!!!!!!testing:550': 0.95200002}
551 14.7252
552 6.515
553 8.9048
554 7.67524
···
1098 11.3033
1099 8.11119
{'!!!!!!!!!!!!!!testing:1100': 0.977}
1101 8.80489
1102 6.89097
···
···
第一轮跑完后的正确率就有95.2%了,
第二轮跑完就有97.7%的正确率了
到第10轮就有99.0%以上

Tensorboard分析

我们使用tensorboard进行可视化追踪
这里写图片描述
这个是交叉熵的变化情况,在第一轮完成时(大概第550批左右),loss基本稳定在10.0.之后随着训练次数的增多,loss越来越低。
这里写图片描述
这是Tensorboard追踪到的图结构。

<think>好的,用户想查找使用TensorFlow实现MNIST手写体识别的教程或代码示例。首先,我需要回顾一下现有的引用内容,看看哪些资料可以帮助回答这个问题。 引用[1]提到的是一个Android应用示例,使用TensorFlow Lite进行MNIST识别,可能不是用户需要的完整TensorFlow教程,但可以作为移动端部署的参考。引用[2]和[4]都提到了基础的MNIST识别示例,其中引用[4]详细介绍了使用TensorFlow的实现,包括环境搭建、代码示例和实际应用思路,这应该是核心参考资料。引用[3]和[5]则分别涉及CNN模型和tf.data的使用,可能用户对更高级的实现感兴趣,比如卷积神经网络或者数据管道的优化。引用[5]中的tf.data示例对数据处理有帮助,这可能也是用户需要的部分。 用户的问题比较广泛,可能需要分层次回答。首先应该提供基础的实现步骤和代码,然后介绍更高级的方法如CNN,再提到数据处理的优化,最后补充部署到移动端的可能性。需要确保代码示例符合用户的要求,并且引用正确的来源。同时,根据用户的历史引用,他们可能已经了解了一些TensorFlow的基础,所以可以适当深入,比如提到不同的模型结构或优化技巧。 另外,用户可能需要教程链接或代码结构,因此需要整理引用中的相关信息,比如Github链接、文档地址等。同时,注意回答的结构要清晰,分步骤说明,让用户能够逐步跟随。还需要检查是否有遗漏的部分,例如数据预处理、模型训练、评估和预测等环节是否都涵盖到了。 最后,生成的相关问题需要覆盖用户可能进一步询问的方向,比如不同模型结构、性能优化、部署应用等,这样可以帮助用户深入探索。要确保每个回答的部分都有对应的引用支持,比如引用[3]提到CNN模型,引用[5]提到tf.data的使用,引用[1]关于移动端部署,这些都需要在回答中正确引用。</think>以下是为您整理的TensorFlow实现MNIST手写数字识别的分步指南及代码示例: ### 一、基础实现步骤 1. **环境搭建** - Python 3.7+ + TensorFlow 2.x - 安装命令:`pip install tensorflow`[^2][^3] 2. **数据加载** ```python import tensorflow as tf mnist = tf.keras.datasets.mnist (train_images, train_labels), (test_images, test_labels) = mnist.load_data() train_images, test_images = train_images/255.0, test_images/255.0 # 归一化 ``` 3. **全连接神经网络模型** ```python model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) ``` 4. **训练与评估** ```python model.fit(train_images, train_labels, epochs=5) model.evaluate(test_images, test_labels) ``` 基础模型准确率可达约98%[^2][^4] ### 二、进阶CNN实现 ```python # 数据维度扩展 train_images = train_images[..., tf.newaxis] test_images = test_images[..., tf.newaxis] # CNN模型 model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)), tf.keras.layers.MaxPooling2D((2,2)), tf.keras.layers.Conv2D(64, (3,3), activation='relu'), tf.keras.layers.MaxPooling2D((2,2)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10) ]) model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy']) ``` CNN模型准确率可达99%+[^3][^4] ### 、代码资源推荐 1. **官方示例**:TensorFlow官网MNIST教程 https://tensorflow.google.cn/tutorials/quickstart/beginner[^2] 2. **项目结构示例** ``` project/ ├── train.py # 训练脚本 ├── predict.py # 预测脚本 ├── test_images/ # 测试图片 └── checkpoints/ # 模型保存目录[^3] ``` 3. **数据管道优化** ```python # 使用tf.data提升性能 dataset = tf.data.Dataset.from_tensor_slices( (train_images, train_labels)).shuffle(1000).batch(32)[^5] ``` ### 四、移动端部署 通过TensorFlow Lite可将训练好的模型转换为`.tflite`格式,部署到Android设备实现实时识别[^1] 相关问题
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值