TensorFlow神经网络(五)输入手写数字图片进行识别

本文详细介绍了使用TensorFlow实现手写数字识别的过程,包括断点续训策略、模型训练及预测流程。通过实际案例,展示了如何处理输入图片,进行预处理并识别手写数字,揭示了常见错误及解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、断点续训

为防止突然断电、参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt的操作,给所有w和b赋保存在ckpt中的值:
1. 如果存储断点文件的目录文件夹中,包含有效断点状态文件,则返回该文件:

  • 参数说明
    checkpoint_dir: 表示存储断点文件的目录
    latest_filename: 断点文件的可选名称,默认为checkpoint
ckpt = tf.train.get_checkpoint_state(checkpoint_dir,\
 latest_filename = None)

2. 如果ckpt存在,且保存的模型在指定路径中存在

	if ckpt and ckpt.model_checkpoint_path: 

3. 恢复当前会话,将ckpt中的值赋给 w 和 b

  • 参数说明

sess:表示当前会话,之前保存的结果会被加载入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,因为有了位置会自动去查看checkpoint文件,看最新的模型叫什么

	saver.restore(sess, ckpt.model_checkpoint_path) 

4. 完整代码:

# 断点续训 breakpoint_continue.py
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path: 
# 恢复当前会话,将ckpt中的值赋给 w 和 b		
		saver.restore(sess, ckpt.model_checkpoint_path) 

反向传播的with结构中加入加载ckpt的操作后:

# mnist_backward.py
# coding: utf-8

import tensorflow as tf
# 导入imput_data模块
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

# 定义超参数
BATCH_SIZE = 200

LEARNING_RATE_BASE = 0.1 #初始学习率
LEARNING_RATE_DECAY = 0.99 # 学习率衰减率
REGULARIZER = 0.0001 # 正则化参数

STEPS = 50000 #训练轮数

MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
def backward(mnist):
	# placeholder占位
	x = tf.placeholder(tf.float32, shape = (None, mnist_forward.INPUT_NODE))
	y_ = tf.placeholder(tf.float32, shape = (None, mnist_forward.OUTPUT_NODE))

	# 前向传播推测输出y
	y = mnist_forward.forward(x, REGULARIZER)

	# 定义global_step轮数计数器,定义为不可训练
	global_step = tf.Variable(0, trainable = False)

	# 包含正则化的损失函数
	# 交叉熵
	ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = y, \
	labels = tf.argmax(y_, 1))
	cem = tf.reduce_mean(ce)
	# 使用正则化时的损失函数
	loss = cem + tf.add_n(tf.get_collection('losses'))

	# 定义指数衰减学习率
	learning_rate = tf.train.exponential_decay(
		LEARNING_RATE_BASE,
		global_step, 
		mnist.train.num_examples/BATCH_SIZE,
		LEARNING_RATE_DECAY,
		staircase = True)

	# 定义反向传播方法:包含正则化
	train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,\
	 global_step = global_step)

	# 定义滑动平均时,加上:
	ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
	ema_op = ema.apply(tf.trainable_variables())
	with tf.control_dependencies([train_step, ema_op]):
		train_op = tf.no_op(name = 'train')
	
	# 实例化saver
	saver = tf.train.Saver()

	# 训练过程
	with tf.Session() as sess:
		# 初始化所有参数
		init_op = tf.global_variables_initializer()
		sess.run(init_op)

        # 断点续训 breakpoint_continue.py
		ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
		if ckpt and ckpt.model_checkpoint_path: 
		# 恢复当前会话,将ckpt中的值赋给 w 和 b		
				saver.restore(sess, ckpt.model_checkpoint_path) 

		# 循环迭代
		for i in range(STEPS):
			# 将训练集中一定batchsize的数据和标签赋给左边的变量
			xs, ys = mnist.train.next_batch(BATCH_SIZE)
			# 喂入神经网络,执行训练过程train_step
			_, loss_value, step = sess.run([train_op, loss, global_step], \
			feed_dict = {x: xs, y_: ys})

			if i % 1000 == 0: # 拼接成./MODEL_SAVE_PATH/MODEL_NAME-global_step路径
			# 打印提示
				print("after %d steps, loss on traing batch is %g" %(step, loss_value))
				saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), \
				global_step = global_step)
			

def main():
	mnist = input_data.read_data_sets('./data/', one_hot = True)
	# 调用定义好的测试函数
	backward(mnist)
# 判断python运行文件是否为主文件,如果是,则执行
if __name__ == '__main__':
	main()


结果发现模型自动接着之前开机的结束的50000次开始往后训练了,在两次ctrl+C之后,再重新执行时仍可以从断点继续:
在这里插入图片描述


二、如何对输入的手写数字图片,输出正确预测结果

  • 除了minist_forward, mnist_backward, mnist_test之外,增加mnist_app.py一个py文件

自己遇到的问题之
(一)main函数没有写对

if __name__ == '__main__':
	main()

写成了

if __name__ == 'main':
	main()

结果代码根本跑不出结果!!!
(二)input从控制台读入返回的是str型!!!
参见博客https://blog.youkuaiyun.com/qq_41151066/article/details/81745352
所以导致输入图片张数时出现错误:

TypeError: 'str' object cannot be interpreted as an integer

(三)然后又发现这个错误

NameError: name 'raw_input' is not defined

好家伙,参考文章https://blog.youkuaiyun.com/hochean_/article/details/79582627
raw_input改成 input
最终代码改成:

# mnist_app.py
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward


def restore_model(testPicArr):
	# 重现计算图
	with tf.Graph().as_default() as tg:
		x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
		y = mnist_forward.forward(x, None)
		preValue = tf.argmax(y, 1) # y 的最大值对应的列表索引号

		# 实例化带有滑动平均值的saver
		variable_averages = tf.train.ExponentialMovingAverage(\
			mnist_backward.MOVING_AVERAGE_DECAY)
		variables_to_restore = variable_averages.variables_to_restore()
		saver = tf.train.Saver(variables_to_restore)

		# 用with结构加载ckpt
		with tf.Session() as sess:
			ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
			# 如果ckpt存在,恢复ckpt的参数和信息到当前会话
			if ckpt and ckpt.model_checkpoint_path:
				saver.restore(sess, ckpt.model_checkpoint_path)

				# 把刚刚准备好的图片喂入网络,执行预测操作
				preValue = sess.run(preValue, feed_dict = {x: testPicArr})
				return preValue
			else:
				print("No checkpoint file found!")
				return -1


def pre_pic(picName):
	# 打开图片
	img = Image.open(picName)
	# 用消除锯齿的方法resize图片尺寸
	reIm = img.resize((28, 28), Image.ANTIALIAS)
	# 转化成灰度图,并转化成矩阵
	im_arr = np.array(reIm.convert('L'))
	# 二值化阈值
	threshold = 50
	# 模型要求黑底白字,故需要进行反色
	for i in range(28):
		for j in range(28):
			im_arr[i][j] = 255 - im_arr[i][j]
			# 二值化,过滤噪声,留下主要特征
			if(im_arr[i][j] < threshold):
				im_arr[i][j] = 0
			else: im_arr[i][j] = 255
	# 整理矩阵形状
	nm_arr = im_arr.reshape([1, 784])
	# 由于模型要求是浮点数,先改为浮点型
	nm_arr = nm_arr.astype(np.float32)
	# 0到255浮点转化成0到1浮点
	img_ready = np.multiply(nm_arr, 1.0/255.0)

	# 返回预处理完的图片
	return img_ready

def application():
	# 输入要识别的图片数目 # input从控制台读入返回的是str型!!!
	testNum = int(input("Input the number of test pictures:") )
	for i in range(testNum):
		# 给出识别图片的路径 # raw_input从控制台读入字符串
		testPic = input("The path of test pictures:") 
		# 接收的图片需进行预处理
		testPicArr = pre_pic(testPic)
		# 把整理好的图片喂入神经网络
		preValue = restore_model(testPicArr)
		# 输出预测结果
		print("The prediction number is :", preValue)


# 程序从main函数开始执行
def main():
	# 调用application函数
	application()

if __name__ == '__main__':
	main()

and then______
我的人工智障程序识别我画的没有封口的0的结果是2,超难过:
在这里插入图片描述
又画了一张数字1:
在这里插入图片描述
是不是没有训练好,明天再试试。


【接上】2018.11.15
今天重新写了一个2,并且图片改成500*500像素的图片,而不是之前的长宽不一的,如下图:
在这里插入图片描述

然后识别结果就正确了:
在这里插入图片描述
于是接着写了剩下的9个数字,结果如下:
在这里插入图片描述
然后改了一下数字,最终,原谅我只能识别6和9之外的8个数字:
在这里插入图片描述


【接上】2018.11.18更新
https://github.com/cj0012/AI-Practice-Tensorflow-Notes/tree/master/pic 下载了图片,进行识别,结果全都可以识别出来:
在这里插入图片描述

【注】内容来自mooc人工智能实践第六讲

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值