【tensorflow入门】14、卷积神经网络CNN

本文介绍使用卷积神经网络(CNN)解决MNIST手写数字识别问题的方法,通过调整神经元数量和引入dropout避免过拟合,最终在测试集上达到93%以上的识别准确率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文代码使用CNN实现MNIST手写数字识别问题,并统计准确率。

 

本次几乎没踩雷,只是起初搭建的全连接层中神经元个数太多,cpu带不起来,出现一片黄色警告。

心得:

可改代码把数据减少。其实也该研究下gpu的使用了。

代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# number 1 to 10 data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def compute_accuracy(v_xs, v_ys):  #依然是输出准确度 百分之多少
    global prediction
    y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})
    correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})
    return result  

def weight_variable(shape):  #输入shape,返回variable定义的一些参数
    initial = tf.truncated_normal(shape,stddev=0.1)
    return tf.Variable(initial)
 
def bias_variable(shape):  #bias一般是正值比较好,所以0.1
    initial =tf.constant(0.1,shape=shape) #初始0.1,之后从0.1变成其他的值。
    return tf.Variable(initial)

def conv2d(x, W): #convolutional图层  x:输入的值(图片之类的) W:Weights
    #stride[1,x_movement,y_movement,1]
    return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME') 
                        #2维的convolutional neural network  x:图片的所有信息
                                        #strides:步长 4长度的列表 [1,,,1]是规定死的


def max_pool_2x2(x):
    #stride[1,x_movement,y_movement,1]
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')  
                                #实际上就是conv2d中传出的东西再传入max_pooling中去,区别:不需要W
                                #strides的第0位、第3位必须为1.
                                #每2个步长移动一下,实现图片的压缩

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 784]) # 28x28
ys = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)

#定义层之前,先处理下传入的信息,换成另外一种形式
x_image = tf.reshape(xs,[-1,28,28,1])
    #因为x_data中包含了所有的sample(例子)  -1:说明维度先不管它  28*28:像素点  1:channel(黑白图为1,彩色为3)
    #print(x_image.shape) #[n_samples,28,28,1]


## conv1 layer ##
W_conv1 = weight_variable([5,5,1,32])  #patch 5*5   1:in_size   2:out_size
              #小方块patch的长和宽5*5像素,1:输出一个单位的结果,它的高度为32
b_conv1 = bias_variable([32])  #bias只有32的长度
#开始搭建CNN第一层
h_conv1 = tf.nn.relu(conv2d(x_image,W_conv1) + b_conv1)  #其实就是用conv2d来处理Wx_plus_b的关系
                    #加个relu非线性处理,让它非线性化
                    #output size 28*28*32 因为卷积时是以padding=SAME的形式抽出,所以图片长宽不变,只改变了高度,所以28*28
                    #因为卷积[5,5,1,32] ,所以高度变为32
h_pool1 = max_pool_2x2(h_conv1) #output size 14*14*32 因为conv2d的步长[1,1,1,1],而pooling时步长[1,2,2,1]
                    #说明整个图像的长宽缩小了一倍,即除以2


## conv2 layer ##
W_conv2 = weight_variable([5,5,32,64])  #传入变为32,传出假设64,因为要不断变高变厚。
                   #图片原本高度1(黑白),经过convolution1,厚度变为32,这次变为64
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1,W_conv2) + b_conv2)  #output:14*14*64
h_pool2 = max_pool_2x2(h_conv2)                          #output:7*7*64 pooling时高度不变

## func1 layer ##
W_fc1 =weight_variable([7*7*64,256])  #function1 它的输入的形状shape是conv2 layer输出的形状
                               #输出1024,让它变得更宽?高?
b_fc1 = bias_variable([256])
#三维转一维
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])  #[n_samples,7,7,64]这个shape变到[n_samples,7*7*64]
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,W_fc1)+b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)


## func2 layer ##
W_fc2 =weight_variable([256,10])  #传入1024 传出10(0-9)
b_fc2 = bias_variable([10]) 
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop,W_fc2)+b_fc2) #用softmax算概率



# the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
                                              reduction_indices=[1]))       # loss
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) #对于庞大的计算用AdamOptimizer比较好
               #AdamOptimizer需要更小的learning rate,故使用0.0001
sess = tf.Session()
# important step
sess.run(tf.global_variables_initializer())

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(50)
    sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
    if i % 50 == 0:
        print(compute_accuracy(
            mnist.test.images, mnist.test.labels))

 运行结果:

WARNING:tensorflow:From <ipython-input-1-6a24db2fd0d1>:9: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From d:\Users\aDreamer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From d:\Users\aDreamer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-images-idx3-ubyte.gz
WARNING:tensorflow:From d:\Users\aDreamer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From d:\Users\aDreamer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From d:\Users\aDreamer\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
0.0832
0.5769
0.7864
0.826
0.8638
0.8842
0.904
0.9084
0.9211
0.9194
0.9254
0.929
0.9306
0.9364时间关系运行到这我就中止程序了。肉眼可见准确率在提升。

因为本代码计算量较庞大,运行太慢,故我将输出内容全部粘贴过来,以便将来查阅。

本文有很多警告warning,主要是因为其中部分包在将来的tensorflow版本中会被停止使用。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值