TensorFlow Mechanics 101解读
https://www.tensorflow.org/get_started/mnist/mechanics
TensorFlow安装后自带教材,里面包括了8个python文件。
init.py
fully_connected_feed.py
input_data.py
mnist_deep.py
mnist_softmax_xla.py
mnist_softmax.py
mnist_with_summaries.py
mnist.py
init.py没啥说的,一些导入的模块。
mnist_deep.py前文Experts时已讨论过了。
mnist_softmax.py在更早前Beginner时也已说过了。
今天学习mnist.py和fully_connected_feed.py,学习简单的前馈神经网络。
mnist.py构造了一个全连接模型,fully_connected_feed.py训练这个模型。
在eclipse中先跑下fully_connected_feed.py看看输出。
输出:
Extracting /tmp/tensorflow/mnist/input_data\train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data\train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data\t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data\t10k-labels-idx1-ubyte.gz
Step 0: loss = 2.31 (1.344 sec)
Step 100: loss = 2.11 (0.003 sec)
Step 200: loss = 1.77 (0.004 sec)
Step 300: loss = 1.49 (0.004 sec)
Step 400: loss = 1.06 (0.004 sec)
Step 500: loss = 1.06 (0.003 sec)
Step 600: loss = 0.81 (0.004 sec)
Step 700: loss = 0.70 (0.003 sec)
Step 800: loss = 0.63 (0.004 sec)
Step 900: loss = 0.56 (0.004 sec)
Training Data Eval:
Num examples: 55000 Num correct: 47436 Precision @ 1: 0.8625
Validation Data Eval:
Num examples: 5000 Num correct: 4350 Precision @ 1: 0.8700
Test Data Eval:
Num examples: 10000 Num correct: 8642 Precision @ 1: 0.8642
Step 1000: loss = 0.65 (0.021 sec)
Step 1100: loss = 0.52 (0.113 sec)
Step 1200: loss = 0.37 (0.004 sec)
Step 1300: loss = 0.41 (0.003 sec)
Step 1400: loss = 0.39 (0.003 sec)
Step 1500: loss = 0.41 (0.004 sec)
Step 1600: loss = 0.47 (0.004 sec)
Step 1700: loss = 0.47 (0.004 sec)
Step 1800: loss = 0.33 (0.004 sec)
Step 1900: loss = 0.42 (0.003 sec)
Training Data Eval:
Num examples: 55000 Num correct: 49407 Precision @ 1: 0.8983
Validation Data Eval:
Num examples: 5000 Num correct: 4511 Precision @ 1: 0.9022
Test Data Eval:
Num examples: 10000 Num correct: 9028 Precision @ 1: 0.9028
先不管它,确认可以正常运行。
看到mnist的文件名,不用说还是手写字识别问题。
一行一行看吧。先看fully_connected_feed.py
首先看到的是:
parser = argparse.ArgumentParser()
argparse,看名字是参数解析模块。ArgumentParser是里面的一个类,用于将命令行对象转为python对象。
看下说明中的一个例子。
import argparse;
import sys
parser = argparse.ArgumentParser(
description='sum the integers at the command line')
parser.add_argument(
'integers', metavar='int', nargs='+', type=int,
help='an integer to be summed')
parser.add_argument(
'--log', default=sys.stdout, type=argparse.FileType('w'),
help='the file where the sum should be written')
args = parser.parse_args()
print(args.integers) #输出[3, 6, 5, 4]
args.log.write('%s' % sum(args.integers))
args.log.close()
在eclipse IDE的Run Configurations的Arguments中配置Program arguments参数:3 6 5 4
运行后输出:
[3, 6, 5, 4]
18
再来一个和代码更象的(这次不需要设命令行参数):
parser = argparse.ArgumentParser(
description='设CLI参数,然后取出。')
parser.add_argument(
'--learning_rate',
type=float,
default=0.321,
help='Initial learning rate.'
)
args = parser.parse_args()
print(args.learning_rate)
输出:0.321
所以,代码226-274就简单了,给了一些命令行参数而已,整理下:
learning_rate 0.01
max_steps 2000
hidden1 128
hidden2 32
batch_size 100
input_data_dir /tmp/tensorflow/mnist/input_data
log_dir /tmp/tensorflow/mnist/logs/fully_connected_feed
fake_data false
再看下面代码:
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
FLAGS就是把上面设的命令行包在一个名字空间中,unparsed正常情况下是一个空的list,sys.argv[0]就是指fully_connected_feed.py自身。
举例:
自定义文件parser.py:
import argparse;
import sys
parser = argparse.ArgumentParser(
description='设CLI参数,然后取出。')
parser.add_argument(
'--learning_rate',
type=float,
default=0.321,
help='Initial learning rate.'
)
args = parser.parse_args()
FLAGS, unparsed = parser.parse_known_args()
print(FLAGS) #输出Namespace(learning_rate=0.321)
print(unparsed) #输出[]
print(sys.argv[0]) #输出C:\Users\hasee\workspace\tftest\com\101\parser.py
进入main方法后,重建日志目录,运行run_training(),这个简单。后面开始步入正题。
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
读入训练数据,对比前面的调法:
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
设置one_hot=True。这里用不上,fake_data本身就默认false,给不给都一样。总之,拿到数据了。
with tf.Graph().as_default():
Graph是tensorflow.python.framework.ops模块的一个类,表示一个数据流图。一个Graph包括一套操作(或者说计算),和一套对象(或者说数据)。图有很多,但只有一个图是当前缺省的图,as_default()指定这个图为当前上下文的缺省图。它不是线程安全的,所有操作只能从一个线程创建或使用线程同步。
images_placeholder, labels_placeholder = placeholder_inputs(
FLAGS.batch_size)
占位符技术,以前学过。
下面的高潮的部分:
# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder,
FLAGS.hidden1,
FLAGS.hidden2)
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS.learning_rate)
inference定义了模型:使用了从图片输入开始,使用了两个隐藏层和一个softmax层(10类),全连接方式,代码本身容易懂。
name_scope:在定义一个Op时,给出的上下文管理器。所以,每层都在自己的名字空间下运行,不会冲突。如第一隐藏层的权重为hidden1/weights。
loss:对比模型和标签,定义损失函数。使用平均交叉熵,这个在‘Deep MNIST for Experts解读(一)’中说过。
training:使用常见的坡度下降优化器最小化损失,跟踪损失和全局步数,没有什么特别。
TensorBoard待稍后讲可视化的时候再看,现在不管,但以下三句与可视化相关:
summary = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
summary_writer.add_summary(summary_str, step)
保存检查点:
saver = tf.train.Saver()
saver.save(sess, FLAGS.train_dir, global_step=step)
恢复到检查点(本代码不涉及恢复):
saver.restore(sess, FLAGS.train_dir)
如开始跑fully_connected_feed.py的结果:
Test Data Eval:
Num examples: 10000 Num correct: 9028 Precision @ 1: 0.9028
90.28%的正确率不算太高。
本章看到了一个全连接模型,简单。但以后回过头来可以看看tensorboard,checkpoint等的用法。
https://www.tensorflow.org/get_started/mnist/mechanics
TensorFlow安装后自带教材,里面包括了8个python文件。
init.py
fully_connected_feed.py
input_data.py
mnist_deep.py
mnist_softmax_xla.py
mnist_softmax.py
mnist_with_summaries.py
mnist.py
init.py没啥说的,一些导入的模块。
mnist_deep.py前文Experts时已讨论过了。
mnist_softmax.py在更早前Beginner时也已说过了。
今天学习mnist.py和fully_connected_feed.py,学习简单的前馈神经网络。
mnist.py构造了一个全连接模型,fully_connected_feed.py训练这个模型。
在eclipse中先跑下fully_connected_feed.py看看输出。
输出:
Extracting /tmp/tensorflow/mnist/input_data\train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data\train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data\t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data\t10k-labels-idx1-ubyte.gz
Step 0: loss = 2.31 (1.344 sec)
Step 100: loss = 2.11 (0.003 sec)
Step 200: loss = 1.77 (0.004 sec)
Step 300: loss = 1.49 (0.004 sec)
Step 400: loss = 1.06 (0.004 sec)
Step 500: loss = 1.06 (0.003 sec)
Step 600: loss = 0.81 (0.004 sec)
Step 700: loss = 0.70 (0.003 sec)
Step 800: loss = 0.63 (0.004 sec)
Step 900: loss = 0.56 (0.004 sec)
Training Data Eval:
Num examples: 55000 Num correct: 47436 Precision @ 1: 0.8625
Validation Data Eval:
Num examples: 5000 Num correct: 4350 Precision @ 1: 0.8700
Test Data Eval:
Num examples: 10000 Num correct: 8642 Precision @ 1: 0.8642
Step 1000: loss = 0.65 (0.021 sec)
Step 1100: loss = 0.52 (0.113 sec)
Step 1200: loss = 0.37 (0.004 sec)
Step 1300: loss = 0.41 (0.003 sec)
Step 1400: loss = 0.39 (0.003 sec)
Step 1500: loss = 0.41 (0.004 sec)
Step 1600: loss = 0.47 (0.004 sec)
Step 1700: loss = 0.47 (0.004 sec)
Step 1800: loss = 0.33 (0.004 sec)
Step 1900: loss = 0.42 (0.003 sec)
Training Data Eval:
Num examples: 55000 Num correct: 49407 Precision @ 1: 0.8983
Validation Data Eval:
Num examples: 5000 Num correct: 4511 Precision @ 1: 0.9022
Test Data Eval:
Num examples: 10000 Num correct: 9028 Precision @ 1: 0.9028
先不管它,确认可以正常运行。
看到mnist的文件名,不用说还是手写字识别问题。
一行一行看吧。先看fully_connected_feed.py
首先看到的是:
parser = argparse.ArgumentParser()
argparse,看名字是参数解析模块。ArgumentParser是里面的一个类,用于将命令行对象转为python对象。
看下说明中的一个例子。
import argparse;
import sys
parser = argparse.ArgumentParser(
description='sum the integers at the command line')
parser.add_argument(
'integers', metavar='int', nargs='+', type=int,
help='an integer to be summed')
parser.add_argument(
'--log', default=sys.stdout, type=argparse.FileType('w'),
help='the file where the sum should be written')
args = parser.parse_args()
print(args.integers) #输出[3, 6, 5, 4]
args.log.write('%s' % sum(args.integers))
args.log.close()
在eclipse IDE的Run Configurations的Arguments中配置Program arguments参数:3 6 5 4
运行后输出:
[3, 6, 5, 4]
18
再来一个和代码更象的(这次不需要设命令行参数):
parser = argparse.ArgumentParser(
description='设CLI参数,然后取出。')
parser.add_argument(
'--learning_rate',
type=float,
default=0.321,
help='Initial learning rate.'
)
args = parser.parse_args()
print(args.learning_rate)
输出:0.321
所以,代码226-274就简单了,给了一些命令行参数而已,整理下:
learning_rate 0.01
max_steps 2000
hidden1 128
hidden2 32
batch_size 100
input_data_dir /tmp/tensorflow/mnist/input_data
log_dir /tmp/tensorflow/mnist/logs/fully_connected_feed
fake_data false
再看下面代码:
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
FLAGS就是把上面设的命令行包在一个名字空间中,unparsed正常情况下是一个空的list,sys.argv[0]就是指fully_connected_feed.py自身。
举例:
自定义文件parser.py:
import argparse;
import sys
parser = argparse.ArgumentParser(
description='设CLI参数,然后取出。')
parser.add_argument(
'--learning_rate',
type=float,
default=0.321,
help='Initial learning rate.'
)
args = parser.parse_args()
FLAGS, unparsed = parser.parse_known_args()
print(FLAGS) #输出Namespace(learning_rate=0.321)
print(unparsed) #输出[]
print(sys.argv[0]) #输出C:\Users\hasee\workspace\tftest\com\101\parser.py
进入main方法后,重建日志目录,运行run_training(),这个简单。后面开始步入正题。
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
读入训练数据,对比前面的调法:
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
设置one_hot=True。这里用不上,fake_data本身就默认false,给不给都一样。总之,拿到数据了。
with tf.Graph().as_default():
Graph是tensorflow.python.framework.ops模块的一个类,表示一个数据流图。一个Graph包括一套操作(或者说计算),和一套对象(或者说数据)。图有很多,但只有一个图是当前缺省的图,as_default()指定这个图为当前上下文的缺省图。它不是线程安全的,所有操作只能从一个线程创建或使用线程同步。
images_placeholder, labels_placeholder = placeholder_inputs(
FLAGS.batch_size)
占位符技术,以前学过。
下面的高潮的部分:
# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images_placeholder,
FLAGS.hidden1,
FLAGS.hidden2)
# Add to the Graph the Ops for loss calculation.
loss = mnist.loss(logits, labels_placeholder)
# Add to the Graph the Ops that calculate and apply gradients.
train_op = mnist.training(loss, FLAGS.learning_rate)
inference定义了模型:使用了从图片输入开始,使用了两个隐藏层和一个softmax层(10类),全连接方式,代码本身容易懂。
name_scope:在定义一个Op时,给出的上下文管理器。所以,每层都在自己的名字空间下运行,不会冲突。如第一隐藏层的权重为hidden1/weights。
loss:对比模型和标签,定义损失函数。使用平均交叉熵,这个在‘Deep MNIST for Experts解读(一)’中说过。
training:使用常见的坡度下降优化器最小化损失,跟踪损失和全局步数,没有什么特别。
TensorBoard待稍后讲可视化的时候再看,现在不管,但以下三句与可视化相关:
summary = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
summary_writer.add_summary(summary_str, step)
保存检查点:
saver = tf.train.Saver()
saver.save(sess, FLAGS.train_dir, global_step=step)
恢复到检查点(本代码不涉及恢复):
saver.restore(sess, FLAGS.train_dir)
如开始跑fully_connected_feed.py的结果:
Test Data Eval:
Num examples: 10000 Num correct: 9028 Precision @ 1: 0.9028
90.28%的正确率不算太高。
本章看到了一个全连接模型,简单。但以后回过头来可以看看tensorboard,checkpoint等的用法。