用Tensorflow搭建CNN卷积神经网络,实现MNIST手写数字识别

本文详细介绍了如何使用Tensorflow构建CNN,从本地加载并预处理MNIST数据,创建批量训练,构建网络结构,训练并评估模型,最后还提供了保存和加载模型的方法。涉及的第三方库包括Tensorflow、Pillow、NumPy和Matplotlib。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

写在前面的话

不同于Tensorflow官方教程简略的DEMO,我们自己动手实现以下目标
- 从本地文件系统中加载图片、标签
- 对图片和标签预处理
- 创建batch对象以提供随机批次训练
- 构建网络结构
- 训练神经网络
- 在验证集合上评估准确率
- 保存及加载网络参数模型

CNN卷积神经网络的基础知识及简介,我推荐这篇文章。
http://brohrer.github.io/how_convolutional_neural_networks_work.html

训练数据

请下载 https://pan.baidu.com/s/1cdBnbC
训练集合
train.txt 图片文件名-标签
train.rar 图片库,请解压为train文件夹
验证集合
val.txt 图片名-标签
val.rar 图片库,请解压为val文件夹

本程序用到的第三方库

  • tensorflow (1.0.1) 谷歌基于DistBelief进行研发的第二代人工智能学习系统
  • Pillow (3.2.0) 基本的图像处理功能
  • numpy (1.12.1) 开源的数值计算扩展
  • matplotlib (1.5.1) Python 的 2D绘图库

干货,代码及讲解

定义输入数据路径

TRAIN_LABEL_PATH = "/path/to/train.txt"
TRAIN_IMAGE_PATH = "/path/to/train/"
VAL_LABEL_PATH = "/path/to/val.txt"
VAL_IMAGE_PATH = "/path/to/val/"

定义函数辅助构建神经网络

def weigth_variable(shape, name):
    # 这里使用截断的正态分布,标准差为0.1
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name = name)

def bias_variable(shape):
    # bias初始化为0.1避免死亡节点
    initial = tf.constant(0.1, shape=shape)
    return initial

def conv2d(x, W):
    # 参数中x是输入,W是卷积的参数,比如[5,5,1,32]:前面两个数字代表卷积核的尺寸;第三个数字代表有多少个channel。因为我们只有灰度单色,所以是1,如果是RGB彩色图片,这里应该是3。
    # 最后一个数字代表卷积核的数量,也就是这个卷积层会提取多少类的特征。
    # Strides代表卷积模板移动的步长,都是1代表会不遗漏地划过图片的每一个点。
    # Padding代表边界的处理方式,这里的SAME代表给边界加上Padding让卷积的输出和输入保持同样(SAME)是尺寸。
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")

def
利用tensorflow实现卷积神经网络来进行MNIST手写数字图像的分类。 #导入numpy模块 import numpy as np #导入tensorflow模块,程序使用tensorflow实现卷积神经网络 import tensorflow as tf #下载mnist数据集,并从mnist_data目录中读取数据 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('mnist_data',one_hot=True) #(1)这里的“mnist_data” 是和当前文件相同目录下的一个文件夹。自己先手工建立这个文件夹,然后从https://yann.lecun.com/exdb/mnist/ 下载所需的4个文件(即该网址中第三段“Four files are available on this site:”后面的四个文件),并放到目录MNIST_data下即可。 #(2)MNIST数据集是手写数字字符的数据集。每个样本都是一张28*28像素的灰度手写数字图片。 #(3)one_hot表示独热编码,其值被设为true。在分类问题的数据集标注时,如何不采用独热编码的方式, 类别通常就是一个符号而已,比如说是9。但如果采用独热编码的方式,则每个类表示为一个列表list,共计有10个数值,但只有一个为1,其余均为0。例如,“9”的独热编码可以为[00000 00001]. #定义输入数据x和输出y的形状。函数tf.placeholder的目的是定义输入,可以理解为采用占位符进行占位。 #None这个位置的参数在这里被用于表示样本的个数,而由于样本个数此时具体是多少还无法确定,所以这设为None。而每个输入样本的特征数目是确定的,即为28*28。 input_x = tf.placeholder(tf.float32,[None,28*28])/255 #因为每个像素的取值范围是 0~255 output_y = tf.placeholder(tf.int32,[None,10]) #10表示10个类别 #输入层的输入数据input_x被reshape成四维数据,其中第一维的数据代表了图片数量 input_x_images = tf.reshape(input_x,[-1,28,28,1]) test_x = mnist.test.images[:3000] #读取测试集图片的特征,读取3000个图片 test_y = mnist.test.labels[:3000] #读取测试集图片的标签。就是这3000个图片所对应的标签
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值