This code was written against TensorFlow v1.4 and should run on any TF 1.x release. The example does not require downloading MNIST by hand: the TF API downloads the dataset and unpacks it into the directory given by the args.data_dir argument.
Simple one-layer ANN implementation
#!/usr/bin/env python
# coding=utf-8
# author: dongzhou
import tensorflow as tf
import os
import numpy as np
import argparse
import shutil
from tensorflow.examples.tutorials.mnist import input_data

parser = argparse.ArgumentParser('MNIST Softmax')
parser.add_argument('--data_dir', type=str, default='/tmp/mnist-data',
                    help='the directory of MNIST dataset')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--model_path', type=str, default='', help='the path of checkpoint file')
args = parser.parse_args()


def model():
    x = tf.placeholder(tf.float32, [None, 784], name='x')
    w1 = tf.Variable(tf.zeros([784, 10]), name='weight1')
    b1 = tf.Variable(tf.zeros([10]), name='bias1')
    y = tf.matmul(x, w1) + b1
    gt = tf.placeholder(tf.float32, [None, 10], name='groundtruth')
    # losses
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=gt, logits=y))
    # optimizer
    optimizer = tf.train.GradientDescentOptimizer(args.lr)
    # define one-step train ops
    train_op = optimizer.minimize(cross_entropy)
    return x, y, gt, train_op


if __name__ == "__main__":
    max_train_step = 10000
    mnist = input_data.read_data_sets(args.data_dir, one_hot=True)
    x, y, gt, train_op = model()
    # build the evaluation ops once, outside the training loop,
    # so the graph does not keep growing on every evaluation
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(gt, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # create saver
    saver = tf.train.Saver(var_list=tf.global_variables())
    if os.path.exists('./mnist'):
        print('=> deleting old temporary directory ...')
        shutil.rmtree('./mnist')
        print('=> creating new temporary directory ...')
        os.makedirs('./mnist')
    else:
        print('=> creating temporary directory ...')
        os.makedirs('./mnist')
    with tf.Session() as sess:
        if args.model_path == '':
            tf.global_variables_initializer().run()
        else:
            saver.restore(sess, args.model_path)
        for i in range(max_train_step):
            batch_x, batch_gt = mnist.train.next_batch(100)
            sess.run(train_op, feed_dict={x: batch_x, gt: batch_gt})
            if i % 100 == 0:
                print('=> accuracy: {}'.format(sess.run(accuracy, feed_dict={x: mnist.test.images, gt: mnist.test.labels})))
                saver.save(sess, 'mnist/mnist_{:02d}.ckpt'.format(i + 1))
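To evaluate a saved checkpoint without retraining, the script's own --model_path flag can be pointed at any of the files written to ./mnist. Alternatively, here is a minimal standalone sketch (my addition, not part of the original; it assumes model() and args from the script above are in scope) that restores the newest checkpoint via tf.train.latest_checkpoint:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)
x, y, gt, _ = model()
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore the most recent checkpoint written by the training loop
    saver.restore(sess, tf.train.latest_checkpoint('./mnist'))
    correct = tf.equal(tf.argmax(y, 1), tf.argmax(gt, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    print('=> test accuracy: {}'.format(
        sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      gt: mnist.test.labels})))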
Two-layer ANN implementation
[Question]:
The code above is only a simple one-layer ANN, i.e. Y = W*X + b; after 10,000 training steps its test accuracy reaches about 0.91. But with a two-layer ANN, i.e. Y = W_2*(W_1*X + b_1) + b_2, the test accuracy is terrible (about 0.13). Why is that?
The two-layer model-construction code is as follows:
def model():
    x = tf.placeholder(tf.float32, [None, 784], name='x')
    w1 = tf.Variable(tf.zeros([784, 512]), name='weight1')
    b1 = tf.Variable(tf.zeros([512]), name='bias1')
    w2 = tf.Variable(tf.zeros([512, 10]), name='weight2')
    b2 = tf.Variable(tf.zeros([10]), name='bias2')
    h = tf.matmul(x, w1) + b1
    y = tf.matmul(h, w2) + b2
    gt = tf.placeholder(tf.float32, [None, 10], name='groundtruth')
    # losses
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=gt, logits=y))
    # optimizer
    optimizer = tf.train.GradientDescentOptimizer(args.lr)
    # define one-step train ops
    train_op = optimizer.minimize(cross_entropy)
    return x, y, gt, train_op
[Answer]:
After some investigation, the cause turned out to be that every layer's weights were initialized to zero. This leaves the network stuck at a degenerate stationary point it cannot escape: with all-zero weights the hidden activations are all zero, and the gradient flowing back through w2 is zero as well, so w1 and w2 never receive a non-zero update (only b2 moves).
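This can be verified with a few lines of numpy. The sketch below is my addition (shapes follow the model above; the random batch and labels are only illustrative): it runs one backward pass through the zero-initialized two-layer model and shows that both weight gradients are exactly zero.

import numpy as np

x  = np.random.rand(100, 784)                    # one illustrative batch
gt = np.eye(10)[np.random.randint(0, 10, 100)]   # one-hot labels

w1, b1 = np.zeros((784, 512)), np.zeros(512)     # all-zero init, as in the question
w2, b2 = np.zeros((512, 10)),  np.zeros(10)

h = x @ w1 + b1                                  # == 0 everywhere
y = h @ w2 + b2                                  # == 0 everywhere
p = np.exp(y) / np.exp(y).sum(1, keepdims=True)  # softmax over the logits

dy  = (p - gt) / len(x)                          # softmax cross-entropy gradient
dw2 = h.T @ dy                                   # h == 0  => dw2 == 0
dh  = dy @ w2.T                                  # w2 == 0 => dh == 0
dw1 = x.T @ dh                                   #         => dw1 == 0

print(np.abs(dw1).max(), np.abs(dw2).max())      # prints: 0.0 0.0

Random initialization breaks this symmetry. The fixed two-layer model, which also introduces variable scopes, is as follows: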
def model():
    x = tf.placeholder(tf.float32, [None, 784], name='x')
    gt = tf.placeholder(tf.float32, [None, 10], name='groundtruth')
    with tf.variable_scope('layer1'):
        w1 = tf.get_variable('weight1', [784, 1024], initializer=tf.random_normal_initializer())
        b1 = tf.get_variable('bias1', [1024], initializer=tf.constant_initializer(0.0))
        h = tf.matmul(x, w1) + b1
    with tf.variable_scope('layer2'):
        w2 = tf.get_variable('weight2', [1024, 10], initializer=tf.random_normal_initializer())
        b2 = tf.get_variable('bias2', [10], initializer=tf.constant_initializer(0.0))
        y = tf.matmul(h, w2) + b2
    # losses
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=gt, logits=y))
    # optimizer
    optimizer = tf.train.GradientDescentOptimizer(args.lr)
    # define one-step train ops
    train_op = optimizer.minimize(cross_entropy)
    return x, y, gt, train_op
The final test accuracy now reaches about 0.90. Going deeper from here does not by itself improve accuracy much; for further gains one can swap the activation function, tune the learning rate, and so on.
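As one concrete sketch of those knobs (my own illustration, not something the original measured; the stddev, decay_steps and decay_rate values are assumptions), the two-layer model could scale down its initializer and decay the learning rate over time, using only standard TF v1 APIs:

def model():
    x = tf.placeholder(tf.float32, [None, 784], name='x')
    gt = tf.placeholder(tf.float32, [None, 10], name='groundtruth')
    # a scaled truncated-normal initializer instead of an unscaled normal one
    init = tf.truncated_normal_initializer(stddev=0.1)
    with tf.variable_scope('layer1'):
        w1 = tf.get_variable('weight1', [784, 1024], initializer=init)
        b1 = tf.get_variable('bias1', [1024], initializer=tf.constant_initializer(0.0))
        h = tf.matmul(x, w1) + b1
    with tf.variable_scope('layer2'):
        w2 = tf.get_variable('weight2', [1024, 10], initializer=init)
        b2 = tf.get_variable('bias2', [10], initializer=tf.constant_initializer(0.0))
        y = tf.matmul(h, w2) + b2
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=gt, logits=y))
    # shrink the learning rate by 10% every 1000 steps (assumed values)
    global_step = tf.Variable(0, trainable=False, name='global_step')
    lr = tf.train.exponential_decay(args.lr, global_step, 1000, 0.9, staircase=True)
    train_op = tf.train.GradientDescentOptimizer(lr).minimize(cross_entropy, global_step=global_step)
    return x, y, gt, train_op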
Three-layer ANN implementation
A three-layer ANN with ReLU activations reaches a final accuracy of about 0.95. The model code is shown below:
def model():
    x = tf.placeholder(tf.float32, [None, 784], name='x')
    gt = tf.placeholder(tf.float32, [None, 10], name='groundtruth')
    with tf.variable_scope('layer1'):
        w1 = tf.get_variable('weight1', [784, 1024], initializer=tf.random_normal_initializer())
        b1 = tf.get_variable('bias1', [1024], initializer=tf.constant_initializer(0.0))
        h1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    with tf.variable_scope('layer2'):
        w2 = tf.get_variable('weight2', [1024, 1024], initializer=tf.random_normal_initializer())
        b2 = tf.get_variable('bias2', [1024], initializer=tf.constant_initializer(0.0))
        h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
    with tf.variable_scope('layer3'):
        w3 = tf.get_variable('weight3', [1024, 10], initializer=tf.random_normal_initializer())
        b3 = tf.get_variable('bias3', [10], initializer=tf.constant_initializer(0.0))
        y = tf.matmul(h2, w3) + b3
    # losses
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=gt, logits=y))
    # optimizer
    optimizer = tf.train.GradientDescentOptimizer(args.lr)
    # define one-step train ops
    train_op = optimizer.minimize(cross_entropy)
    return x, y, gt, train_op
# partial output
"""
=> accuracy: 0.9519000053405762
=> accuracy: 0.9519000053405762
=> accuracy: 0.9520000219345093
=> accuracy: 0.9520000219345093
=> accuracy: 0.9520000219345093
=> accuracy: 0.9520000219345093
=> accuracy: 0.9516000151634216
=> accuracy: 0.9513999819755554
=> accuracy: 0.9513999819755554
=> accuracy: 0.9513999819755554
=> accuracy: 0.9513999819755554
=> accuracy: 0.9513999819755554
=> accuracy: 0.9513999819755554
=> accuracy: 0.9513999819755554
=> accuracy: 0.9509999752044678
=> accuracy: 0.9509999752044678
=> accuracy: 0.9520000219345093
=> accuracy: 0.9509000182151794
=> accuracy: 0.9519000053405762
=> accuracy: 0.9513000249862671
=> accuracy: 0.9513999819755554
=> accuracy: 0.951200008392334
=> accuracy: 0.9514999985694885
=> accuracy: 0.9514999985694885
=> accuracy: 0.9502000212669373
=> accuracy: 0.9508000016212463
=> accuracy: 0.9506999850273132
=> accuracy: 0.9505000114440918
=> accuracy: 0.9505000114440918
=> accuracy: 0.9501000046730042
=> accuracy: 0.9509000182151794
=> accuracy: 0.9509000182151794
"""