First, the modifications to mnist_train.py:
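The essential change from the fully connected version is that the input placeholder x becomes a 4-D tensor matching the shape the convolutional network expects, so inside the training loop (which continues beyond this excerpt) each batch read from mnist.train.next_batch() must be reshaped before being fed. A minimal sketch of that feed step, assuming the session sess and training op train_op defined later in the script:

    xs, ys = mnist.train.next_batch(BATCH_SIZE)
    # Reshape the flat 784-pixel rows into the 4-D layout of the x placeholder.
    reshaped_xs = np.reshape(xs, (BATCH_SIZE,
                                  mnist_inference.IMAGE_SIZE,
                                  mnist_inference.IMAGE_SIZE,
                                  mnist_inference.NUM_CHANNELS))
    sess.run(train_op, feed_dict={x: reshaped_xs, y_: ys})

The modified script: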
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
# Load the constants and the forward-propagation function defined in mnist_inference.py
import mnist_inference
# Configure the neural network parameters
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
# Path and file name for saving the model
MODEL_SAVE_PATH = "/path/to/model/"
MODEL_NAME = "model.ckpt"
def train(mnist):
    # Define the input and output placeholders. The input is now a 4-D tensor
    # so that it can be fed directly to the convolutional network.
    x = tf.placeholder(
        tf.float32,
        [BATCH_SIZE,
         mnist_inference.IMAGE_SIZE,
         mnist_inference.IMAGE_SIZE,
         mnist_inference.NUM_CHANNELS],
        name='x-input')
    y_ = tf.placeholder(
        tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # Use the forward-propagation process defined in mnist_inference.py.
    # The second argument is a boolean flag (True during training so dropout
    # is applied); passing the enclosing train function here would be a bug.
    y = mnist_inference.inference(x, True, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Maintain an exponential moving average of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(
        tf.trainable_variables())

    # Cross-entropy loss plus the L2 regularization terms collected in 'losses'.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Exponentially decaying learning rate.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,