mnist

该博客使用TensorFlow搭建BP神经网络进行数字识别。设置变量后,搭建计算网络,采用relu函数作为激励函数,交叉熵损失函数。经过多次迭代计算,打印训练正确率,还对测试集输出结果可视化,并验证了手写图片的识别效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import os ##
os.environ[‘TF_CPP_MIN_LOG_LEVEL’] = ‘2’ ##
import tensorflow as tf
import urllib
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter

old_v = tf.logging.get_verbosity() ##
tf.logging.set_verbosity(tf.logging.ERROR) ##
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot = True)

#建立BP神经网络模型
num_classes = 10#数据类型0-9
input_size = 784#28*28
hidden_units_size = 30#层节点数
batch_size = 100#
training_iterations = 50000#迭代次数

设置变量

X = tf.placeholder (tf.float32, shape = [None, input_size]) #将一个二维图像展平后,放入一个长度为784的数组中。
Y = tf.placeholder (tf.float32, shape = [None, num_classes]) #输出结果是0-9的数字,所以只有10种结构。
W1 = tf.Variable (tf.random_normal ([input_size, hidden_units_size],
stddev = 0.1))#hidden_units_size = 30#正态分布随机数
B1 = tf.Variable (tf.constant (0.1),
[hidden_units_size])#常数为1,形状为(1,1)
W2 = tf.Variable (tf.random_normal ([hidden_units_size,
num_classes], stddev = 0.1))#正态分布随机数
B2 = tf.Variable (tf.constant (0.1), [num_classes])

搭建计算网络 使用 relu 函数作为激励函数 这个函数就是 y = max (0,x) 的一个类似线性函数 拟合程度还是不错的

使用交叉熵损失函数 这是分类问题例如 : 神经网络 对率回归经常使用的一个损失函数

#第1层神经网络
hidden_opt = tf.matmul (X, W1) + B1#矩阵运算
hidden_opt = tf.nn.relu (hidden_opt)#激活函数
#第2层神经网络
final_opt = tf.matmul (hidden_opt, W2) + B2#矩阵运算
final_opt = tf.nn.relu (final_opt)#激活函数,最终的输出结果
loss = tf.reduce_mean (
tf.nn.softmax_cross_entropy_with_logits (labels = Y, logits = final_opt))#损失函数,交叉熵方法
opt = tf.train.GradientDescentOptimizer (0.1).minimize (loss)
init = tf.global_variables_initializer ()#全局变量初始化
correct_prediction = tf.equal (tf.argmax (Y, 1), tf.argmax (final_opt, 1))
accuracy = tf.reduce_mean (tf.cast (correct_prediction, ‘float’))#将张量转化成float

进行计算 打印正确率

sess = tf.Session ()#生成能进行TensorFlow计算的类
sess.run (init)
for i in range (training_iterations) :
batch = mnist.train.next_batch (batch_size)#每次迭代选用的样本数100
batch_input = batch[0]
batch_labels = batch[1]
training_loss = sess.run ([opt, loss], feed_dict = {X: batch_input, Y: batch_labels})
if (i+1) % 10000 == 0 :
train_accuracy = accuracy.eval (session = sess, feed_dict = {X: batch_input,Y: batch_labels})
print ("step : %d, training accuracy = %g " % (i+1, train_accuracy))

###测试集输出结果可视化
def res_Visual(n):
#sess=tf.Session()
#sess.run(tf.global_variables_initializer())
final_opt_a=tf.argmax (final_opt, 1).eval(session=sess,feed_dict = {X: mnist.test.images,Y: mnist.test.labels})
fig, ax = plt.subplots(nrows=int(n/5),ncols=5 )
ax = ax.flatten()
print(‘前{}张图片预测结果为:’.format(n))
for i in range(n):
print(final_opt_a[i],end=’,’)
if int((i+1)%5) ==0:
print(’\t’)
#图片可视化展示
img = mnist.test.images[i].reshape((28,28))#读取每行数据,格式为Ndarry
ax[i].imshow(img, cmap=‘Greys’, interpolation=‘nearest’)#可视化
print(‘测试集前{}张图片为:’.format(n))
plt.show() ##
res_Visual(20)

#验证自己手写图片的识别效果
#导入图片,二值化,并输出模型可识别的格式
def image_to_number(n):
from PIL import Image
import numpy as np
fig, ax = plt.subplots(nrows=int(n/5),ncols=5 )
ax = ax.flatten()
image_test = []
label_test = np.zeros((n,10))#手写图片的lebel
for i in range(n):
label_test[i][i] =1#将(0,0)(1,1)等位置赋值为1
line = []
img = Image.open("{}.png".format(i)) # 打开一个图片,并返回图片对象
img = img.convert(‘L’) # 转换为灰度,img.show()可查看图片
img = img.resize((28,28)) # 将图片重新以(w,h)尺寸存储
for y in range(28):
for x in range(28):
line.append((255-img.getpixel((x,y)))/255)# getpixel 获取该位置的像素信息
image_test.append(line)#存储像素点信息
line = np.array(line)#转化为np.array
ax[i].imshow(line.reshape(28,28), cmap=‘Greys’, interpolation=‘nearest’)
#plt.imshow(line.reshape(28,28), cmap=‘Greys’)#显示图片,imshow能够将数字转换为灰度显示出图像
image_test = np.array(image_test)
plt.show() ##
return image_test,label_test
image_test,label_test = image_to_number(10)
plt.show() ##

tf.logging.set_verbosity(old_v) ##

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值