一、断点续训
为防止突然断电、参数白跑的情况发生,在backward中加入类似于之前test中加载ckpt
的操作,给所有w和b赋保存在ckpt中的值:
1. 如果存储断点文件的目录文件夹中,包含有效断点状态文件,则返回该文件:
- 参数说明
checkpoint_dir
: 表示存储断点文件的目录
latest_filename
: 断点文件的可选名称,默认为checkpoint
ckpt = tf.train.get_checkpoint_state(checkpoint_dir,\
latest_filename = None)
2. 如果ckpt存在,且保存的模型在指定路径中存在
if ckpt and ckpt.model_checkpoint_path:
3. 恢复当前会话,将ckpt中的值赋给 w 和 b
- 参数说明
sess
:表示当前会话,之前保存的结果会被加载入这个会话
ckpt.model_checkpoint_path
:表示模型存储的位置,不需要提供模型的名字,因为有了位置会自动去查看checkpoint文件,看最新的模型叫什么
saver.restore(sess, ckpt.model_checkpoint_path)
4. 完整代码:
# 断点续训 breakpoint_continue.py
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复当前会话,将ckpt中的值赋给 w 和 b
saver.restore(sess, ckpt.model_checkpoint_path)
在反向传播的with结构中加入加载ckpt的操作后:
# mnist_backward.py
# coding: utf-8
import tensorflow as tf
# 导入imput_data模块
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
# 定义超参数
BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1 #初始学习率
LEARNING_RATE_DECAY = 0.99 # 学习率衰减率
REGULARIZER = 0.0001 # 正则化参数
STEPS = 50000 #训练轮数
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
def backward(mnist):
# placeholder占位
x = tf.placeholder(tf.float32, shape = (None, mnist_forward.INPUT_NODE))
y_ = tf.placeholder(tf.float32, shape = (None, mnist_forward.OUTPUT_NODE))
# 前向传播推测输出y
y = mnist_forward.forward(x, REGULARIZER)
# 定义global_step轮数计数器,定义为不可训练
global_step = tf.Variable(0, trainable = False)
# 包含正则化的损失函数
# 交叉熵
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = y, \
labels = tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
# 使用正则化时的损失函数
loss = cem + tf.add_n(tf.get_collection('losses'))
# 定义指数衰减学习率
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples/BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase = True)
# 定义反向传播方法:包含正则化
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,\
global_step = global_step)
# 定义滑动平均时,加上:
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
train_op = tf.no_op(name = 'train')
# 实例化saver
saver = tf.train.Saver()
# 训练过程
with tf.Session() as sess:
# 初始化所有参数
init_op = tf.global_variables_initializer()
sess.run(init_op)
# 断点续训 breakpoint_continue.py
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复当前会话,将ckpt中的值赋给 w 和 b
saver.restore(sess, ckpt.model_checkpoint_path)
# 循环迭代
for i in range(STEPS):
# 将训练集中一定batchsize的数据和标签赋给左边的变量
xs, ys = mnist.train.next_batch(BATCH_SIZE)
# 喂入神经网络,执行训练过程train_step
_, loss_value, step = sess.run([train_op, loss, global_step], \
feed_dict = {x: xs, y_: ys})
if i % 1000 == 0: # 拼接成./MODEL_SAVE_PATH/MODEL_NAME-global_step路径
# 打印提示
print("after %d steps, loss on traing batch is %g" %(step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), \
global_step = global_step)
def main():
mnist = input_data.read_data_sets('./data/', one_hot = True)
# 调用定义好的测试函数
backward(mnist)
# 判断python运行文件是否为主文件,如果是,则执行
if __name__ == '__main__':
main()
结果发现模型自动接着之前开机的结束的50000次开始往后训练了,在两次ctrl+C之后,再重新执行时仍可以从断点继续:
二、如何对输入的手写数字图片,输出正确预测结果
- 除了
minist_forward, mnist_backward, mnist_test
之外,增加mnist_app.py
一个py文件
自己遇到的问题之
(一)main函数没有写对
把
if __name__ == '__main__':
main()
写成了
if __name__ == 'main':
main()
结果代码根本跑不出结果!!!
(二)input从控制台读入返回的是str型!!!
参见博客https://blog.youkuaiyun.com/qq_41151066/article/details/81745352
所以导致输入图片张数时出现错误:
TypeError: 'str' object cannot be interpreted as an integer
(三)然后又发现这个错误
NameError: name 'raw_input' is not defined
好家伙,参考文章https://blog.youkuaiyun.com/hochean_/article/details/79582627
把raw_input
改成 input
最终代码改成:
# mnist_app.py
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward
def restore_model(testPicArr):
# 重现计算图
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, None)
preValue = tf.argmax(y, 1) # y 的最大值对应的列表索引号
# 实例化带有滑动平均值的saver
variable_averages = tf.train.ExponentialMovingAverage(\
mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# 用with结构加载ckpt
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
# 如果ckpt存在,恢复ckpt的参数和信息到当前会话
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
# 把刚刚准备好的图片喂入网络,执行预测操作
preValue = sess.run(preValue, feed_dict = {x: testPicArr})
return preValue
else:
print("No checkpoint file found!")
return -1
def pre_pic(picName):
# 打开图片
img = Image.open(picName)
# 用消除锯齿的方法resize图片尺寸
reIm = img.resize((28, 28), Image.ANTIALIAS)
# 转化成灰度图,并转化成矩阵
im_arr = np.array(reIm.convert('L'))
# 二值化阈值
threshold = 50
# 模型要求黑底白字,故需要进行反色
for i in range(28):
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j]
# 二值化,过滤噪声,留下主要特征
if(im_arr[i][j] < threshold):
im_arr[i][j] = 0
else: im_arr[i][j] = 255
# 整理矩阵形状
nm_arr = im_arr.reshape([1, 784])
# 由于模型要求是浮点数,先改为浮点型
nm_arr = nm_arr.astype(np.float32)
# 0到255浮点转化成0到1浮点
img_ready = np.multiply(nm_arr, 1.0/255.0)
# 返回预处理完的图片
return img_ready
def application():
# 输入要识别的图片数目 # input从控制台读入返回的是str型!!!
testNum = int(input("Input the number of test pictures:") )
for i in range(testNum):
# 给出识别图片的路径 # raw_input从控制台读入字符串
testPic = input("The path of test pictures:")
# 接收的图片需进行预处理
testPicArr = pre_pic(testPic)
# 把整理好的图片喂入神经网络
preValue = restore_model(testPicArr)
# 输出预测结果
print("The prediction number is :", preValue)
# 程序从main函数开始执行
def main():
# 调用application函数
application()
if __name__ == '__main__':
main()
and then______
我的人工智障程序识别我画的没有封口的0的结果是2,超难过:
又画了一张数字1:
是不是没有训练好,明天再试试。
【接上】2018.11.15
今天重新写了一个2,并且图片改成500*500像素的图片,而不是之前的长宽不一的,如下图:
然后识别结果就正确了:
于是接着写了剩下的9个数字,结果如下:
然后改了一下数字,最终,原谅我只能识别6和9之外的8个数字:
【接上】2018.11.18更新
从 https://github.com/cj0012/AI-Practice-Tensorflow-Notes/tree/master/pic 下载了图片,进行识别,结果全都可以识别出来:
【注】内容来自mooc人工智能实践第六讲