Model Persistence
To make the code more readable and extensible, it should be split into modules by responsibility, with reusable code factored out into library functions. The previously monolithic MNIST code can therefore be divided into three modules:
- inference
- train
- eval
The directory layout is as follows:
mnist/
    data/
        ......
    best/
        inference.py
        train.py
        eval.py
Complete Code
First up is inference.py, the library module that implements the forward-propagation pass used by both training and evaluation:
import tensorflow as tf

# Network architecture parameters
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

# Create a weight variable and add its regularization loss to the 'losses' collection
def get_weight_variable(shape, regularizer):
    weights = tf.get_variable(
        'weights',
        shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights

# Forward propagation
def inference(input_tensor, regularizer):
    # Declare the hidden-layer variables and compute the hidden activations
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            'biases', [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    # Declare the output-layer variables and compute the logits
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            'biases', [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2
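
The 'losses' collection is what ties inference.py to the loss computation in train.py: each call to get_weight_variable adds one L2 term to the collection. A minimal sketch (not part of the original modules, assuming inference.py is on the import path) that makes this visible:

import tensorflow as tf

import inference

x = tf.placeholder(tf.float32, [None, inference.INPUT_NODE])
regularizer = tf.contrib.layers.l2_regularizer(0.0001)
y = inference.inference(x, regularizer)

# One L2 term was added per weight matrix, so the collection holds two tensors
print(len(tf.get_collection('losses')))  # -> 2
total_regularization = tf.add_n(tf.get_collection('losses'))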
Next is train.py, the module that trains the model:
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import inference

# Optimizer hyperparameters
LEARNING_RATE_BASE = 0.8      # base learning rate
LEARNING_RATE_DECAY = 0.99    # learning-rate decay rate
REGULARIZATION_RATE = 0.0001  # weight of the regularization term in the loss
MOVING_AVERAGE_DECAY = 0.99   # moving-average decay rate

# Training parameters
BATCH_SIZE = 100        # number of images per training batch
TRAINING_STEPS = 30000  # number of training steps

# Path and file name for saving the model
MODEL_SAVE_PATH = 'model/'
MODEL_NAME = 'mnist.ckpt'

def train(mnist):
    # Build the model
    x = tf.placeholder(
        tf.float32, [None, inference.INPUT_NODE], name='x-input')  # input layer
    y_ = tf.placeholder(
        tf.float32, [None, inference.OUTPUT_NODE], name='y-input')  # labels
    regularizer = tf.contrib.layers.l2_regularizer(
        REGULARIZATION_RATE)  # L2 regularization for the loss function
    y = inference.inference(x, regularizer)  # output layer

    # Track the number of training steps; the counter itself is not trainable
    global_step = tf.Variable(0, trainable=False)

    # Set up moving averages
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)  # define the moving-average object
    variable_averages_op = variable_averages.apply(
        tf.trainable_variables())  # maintain moving averages of all trainable variables

    # Set up exponential learning-rate decay
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY)

    # Minimize the loss function
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))  # cross-entropy for each image
    cross_entropy_mean = tf.reduce_mean(cross_entropy)  # mean cross-entropy over the batch
    loss = cross_entropy_mean + tf.add_n(
        tf.get_collection('losses'))  # total loss = cross-entropy + regularization
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)  # optimize the loss

    # Run backpropagation and the moving-average update together
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    # Create the persistence helper
    saver = tf.train.Saver()

    # Start training
    with tf.Session() as sess:
        # Initialize all variables
        tf.global_variables_initializer().run()
        # Training loop
        for i in range(TRAINING_STEPS):
            # Fetch this step's batch
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={
                    x: xs,
                    y_: ys
                })
            # Save the model every 1000 steps
            if i % 1000 == 0:
                # Report training progress
                print('After %d training steps, loss is %g.' % (step,
                                                                loss_value))
                # Save the current model
                saver.save(
                    sess,
                    os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                    global_step=global_step)

# Program entry point
def main(argv=None):
    mnist = input_data.read_data_sets('../data/', one_hot=True)
    train(mnist)

if __name__ == '__main__':
    tf.app.run()
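
Because saver.save is called with global_step, the step number is appended to the file name, so model/ fills with numbered checkpoints such as mnist.ckpt-1001 plus a checkpoint index file recording the latest one; by default tf.train.Saver retains only the five most recent. A small sketch (hypothetical, not part of the original modules) for inspecting what was written:

import tensorflow as tf

ckpt = tf.train.get_checkpoint_state('model/')
print(ckpt.model_checkpoint_path)       # latest checkpoint, e.g. model/mnist.ckpt-29001
print(ckpt.all_model_checkpoint_paths)  # all retained checkpoints (5 by default)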
Finally, eval.py, which can run while the model is training and periodically evaluates the most recently saved checkpoint:
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import inference
import train

# Load the latest model every 10 seconds and measure its accuracy on the validation data
EVAL_INTERVAL_SECS = 10

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        # Define the input and output format
        x = tf.placeholder(
            tf.float32, [None, inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(
            tf.float32, [None, inference.OUTPUT_NODE], name='y-input')
        y = inference.inference(x, None)

        # Validation set
        validate_feed = {
            x: mnist.validation.images,
            y_: mnist.validation.labels
        }

        # Evaluation ops
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Load the model via variable renaming, so the moving-average values are restored
        variable_averages = tf.train.ExponentialMovingAverage(
            train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Check the accuracy every 10 seconds
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    # Load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Recover the training step count from the checkpoint file name
                    global_step = ckpt.model_checkpoint_path.split('/')[
                        -1].split('-')[-1]
                    # Evaluate and report the result
                    accuracy_score = sess.run(
                        accuracy, feed_dict=validate_feed)
                    print(
                        'After %s training steps, validation accuracy = %g' %
                        (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):
    mnist = input_data.read_data_sets('../data/', one_hot=True)
    evaluate(mnist)

if __name__ == '__main__':
    tf.app.run()
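
The key to loading the averaged weights is variables_to_restore(): it returns a dict mapping each variable's shadow (moving-average) name in the checkpoint to the corresponding variable in the current graph, so saver.restore writes the averaged values into the plain variables. A short sketch (assuming the same graph as in evaluate) to print that mapping:

variable_averages = tf.train.ExponentialMovingAverage(train.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
for checkpoint_name, variable in variables_to_restore.items():
    print(checkpoint_name, '->', variable.op.name)
    # e.g. layer1/weights/ExponentialMovingAverage -> layer1/weights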
Results
The results of training the model with train.py:
$ python train.py
Extracting ../data/train-images-idx3-ubyte.gz
Extracting ../data/train-labels-idx1-ubyte.gz
Extracting ../data/t10k-images-idx3-ubyte.gz
Extracting ../data/t10k-labels-idx1-ubyte.gz
After 1 training steps, loss is 2.75381.
After 1001 training steps, loss is 0.26364.
After 2001 training steps, loss is 0.160792.
After 3001 training steps, loss is 0.144208.
After 4001 training steps, loss is 0.120926.
After 5001 training steps, loss is 0.10708.
After 6001 training steps, loss is 0.102106.
......
After 22001 training steps, loss is 0.0399828.
After 23001 training steps, loss is 0.0408827.
After 24001 training steps, loss is 0.0355409.
After 25001 training steps, loss is 0.0378072.
After 26001 training steps, loss is 0.0352473.
After 27001 training steps, loss is 0.0357247.
After 28001 training steps, loss is 0.0318179.
After 29001 training steps, loss is 0.0417907.
The results of evaluating the model with eval.py:
$ python eval.py
Extracting ../data/train-images-idx3-ubyte.gz
Extracting ../data/train-labels-idx1-ubyte.gz
Extracting ../data/t10k-images-idx3-ubyte.gz
Extracting ../data/t10k-labels-idx1-ubyte.gz
After 26001 training steps, validation accuracy = 0.983
After 28001 training steps, validation accuracy = 0.985
After 29001 training steps, validation accuracy = 0.986
......
Reposted from: https://blog.youkuaiyun.com/white_idiot/article/details/78777022