从 TensorFlow 提供的 CIFAR-10 分类样例程序中学习训练程序的主要框架。
最外层结构:
- 利用 tf.app.flags 解析命令行参数
- 利用 tf.app.run 启动主函数,在主函数中训练
import tensorflow as tf
# Command-line flags of the training script; they are parsed when
# tf.app.run() starts the program.
FLAGS = tf.app.flags.FLAGS

# Output location for event logs and checkpoints.
tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/cifar10_train',
    'Directory where to write event logs and checkpoint.')

# Training length and diagnostics.
tf.app.flags.DEFINE_integer(
    'max_steps', 100000,
    'Number of batches to run.')
tf.app.flags.DEFINE_boolean(
    'log_device_placement', False,
    'Whether to log device placement.')
tf.app.flags.DEFINE_integer(
    'log_frequency', 10,
    'How often to log results to the console.')
def train():
    """Placeholder for the training routine; the full definition follows below."""
    pass
def main(argv=None):  # pylint: disable=unused-argument
    """Program entry point: fetch the data, reset the output dir, then train.

    Args:
      argv: Remaining command-line arguments passed in by tf.app.run (unused).
    """
    cifar10.maybe_download_and_extract()

    output_dir = FLAGS.train_dir
    # Wipe any previous run so training starts from an empty directory.
    if tf.gfile.Exists(output_dir):
        tf.gfile.DeleteRecursively(output_dir)
    tf.gfile.MakeDirs(output_dir)

    train()
if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then invokes main(argv).
    tf.app.run(main=main)
下面来看训练过程:搭建计算流图,计算 loss,选择合适的优化器、以减小 loss 为目标更新权重,并周期性地记录日志。
def train():
    """Train CIFAR-10 for a number of steps.

    Builds the training graph (inputs -> inference -> loss -> train op) and
    runs it in a MonitoredTrainingSession until FLAGS.max_steps is reached.
    Loss and throughput are printed every FLAGS.log_frequency steps by a
    custom SessionRunHook.
    """
    # `time` and `datetime` are used by the logging hook below but are not
    # imported anywhere in this snippet; import them locally so the function
    # does not fail with NameError on the first logged step.
    # NOTE(review): this snippet also needs `import cifar10` (the model module
    # listed later in this document) at file level.
    import time
    from datetime import datetime

    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Pin the input pipeline to the CPU.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # logits: the graph node that produces the scores fed to the softmax.
        logits = cifar10.inference(images)
        # loss: the graph node that computes the training loss.
        loss = cifar10.loss(logits, labels)
        # train_op: the graph node that updates the weights via backpropagation.
        train_op = cifar10.train(loss, global_step)

        # Subclass SessionRunHook (a callback invoked around every
        # session.run call) to log progress periodically.
        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = (FLAGS.log_frequency * FLAGS.batch_size
                                        / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                # Callbacks run around each session.run call.
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),  # stop at max_steps
                       tf.train.NanTensorHook(loss),  # abort if the loss becomes NaN
                       _LoggerHook()],  # periodic logging
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
上面的代码使用回调函数来写日志,看起来比较高级。初学者可能更习惯把写日志的代码直接放在循环里,像下面这样,就不用多写一个内部类 _LoggerHook:
def train():
    """Train CIFAR-10 for a number of steps, logging from inside the run loop.

    Builds the training graph (inputs -> inference -> loss -> train op) and
    runs it in a MonitoredTrainingSession until FLAGS.max_steps is reached.
    Instead of a logging SessionRunHook, the loss is fetched together with
    train_op on every step and printed every FLAGS.log_frequency steps
    directly in the while loop.
    """
    # `time` and `datetime` are used below but are not imported anywhere in
    # this snippet; import them locally so the loop does not hit NameError.
    # NOTE(review): this snippet also needs `import cifar10` (the model module
    # listed later in this document) at file level.
    import time
    from datetime import datetime

    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Pin the input pipeline to the CPU.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # logits: the graph node that produces the scores fed to the softmax.
        logits = cifar10.inference(images)
        # loss: the graph node that computes the training loss.
        loss = cifar10.loss(logits, labels)
        # train_op: the graph node that updates the weights via backpropagation.
        train_op = cifar10.train(loss, global_step)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss)],  # no logging hook here
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            step = -1
            start_time = time.time()
            while not mon_sess.should_stop():
                step += 1
                lss, _ = mon_sess.run([loss, train_op])
                # Log directly in the loop every FLAGS.log_frequency steps.
                if step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - start_time
                    start_time = current_time

                    examples_per_sec = (FLAGS.log_frequency * FLAGS.batch_size
                                        / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), step, lss,
                                        examples_per_sec, sec_per_batch))
记录的日志如下:
2019-03-19 19:04:37.806384: step 0, loss = 4.67 (260.6 examples/sec; 0.491 sec/batch)
2019-03-19 19:04:42.597658: step 10, loss = 4.62 (267.2 examples/sec; 0.479 sec/batch)
2019-03-19 19:04:47.308927: step 20, loss = 4.49 (271.7 examples/sec; 0.471 sec/batch)
...
最后深入到计算流图的细节,看看每一个节点是怎么搭建的(cifar10.py):
(源代码中关于summary,use_fp16 等细枝末节已被剔除,不影响主干程序)
"""Builds the CIFAR-10 network.
Summary of available functions:
# Compute input images and labels for training. If you would like to run
# evaluations, use inputs() instead.
inputs, labels = distorted_inputs()
# Compute inference on the model inputs to make a prediction.
predictions = inference(inputs)
# Compute the total loss of the prediction with respect to the labels.
loss = loss(predictions, labels)
# Create a graph to run one step of training with respect to the loss.
train_op = train(loss, global_step)
"""
import os
import re
import sys
import tarfile
from six.moves import urllib
import tensorflow as tf
import cifar10_input
FLAGS = tf.app.flags.FLAGS

# Basic model parameters: batch size and where the data set lives on disk.
tf.app.flags.DEFINE_integer(
    'batch_size', 128,
    'Number of images to process in a batch.')
tf.app.flags.DEFINE_string(
    'data_dir', '/tmp/cifar10_data',
    'Path to the CIFAR-10 data directory.')
# --- Data set ---------------------------------------------------------------
# Side length of the (square) images fed to the network.
IMAGE_SIZE = 24
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# --- Training schedule ------------------------------------------------------
INITIAL_LEARNING_RATE = 0.1        # Learning rate at step 0.
LEARNING_RATE_DECAY_FACTOR = 0.1   # Multiplier applied at every decay.
NUM_EPOCHS_PER_DECAY = 350.0       # Epochs between learning-rate decays.
MOVING_AVERAGE_DECAY = 0.9999      # Decay of the trainable-variable averages.

# Multi-GPU training prefixes every op name with this tower name so the
# replicas can be told apart; the prefix is dropped from summary names when
# the model is visualized.
TOWER_NAME = 'tower'

# Binary release of the CIFAR-10 data set.
DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
def _variable_on_cpu(name, shape, initializer):
    """Helper to create a Variable stored on CPU memory.

    The variable is created with tf.get_variable under tf.device('/cpu:0'),
    as float16 when a `use_fp16` flag is defined and true, else float32.

    Args:
      name: name of the variable
      shape: list of ints
      initializer: initializer for Variable

    Returns:
      Variable Tensor
    """
    # The `use_fp16` flag definition is not part of this listing, so reading
    # FLAGS.use_fp16 directly fails with AttributeError; default to float32.
    use_fp16 = getattr(FLAGS, 'use_fp16', False)
    with tf.device('/cpu:0'):
        dtype = tf.float16 if use_fp16 else tf.float32
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
    return var
def _variable_with_weight_decay(name, shape, stddev, wd):
    """Helper to create an initialized Variable with weight decay.

    The Variable is initialized from a truncated normal distribution. If `wd`
    is not None, the term `wd * tf.nn.l2_loss(var)` is added to the 'losses'
    collection, where loss() sums it into the total loss.

    Args:
      name: name of the variable
      shape: list of ints
      stddev: standard deviation of a truncated Gaussian
      wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

    Returns:
      Variable Tensor
    """
    # The `use_fp16` flag definition is not part of this listing, so reading
    # FLAGS.use_fp16 directly fails with AttributeError; default to float32.
    dtype = tf.float16 if getattr(FLAGS, 'use_fp16', False) else tf.float32
    var = _variable_on_cpu(
        name,
        shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if wd is not None:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    # Presumably the directory the DATA_URL binary tarball extracts to -- confirm.
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    # NOTE(review): assumes the reader-based cifar10_input API
    # distorted_inputs(data_dir, batch_size); confirm against the local module.
    images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                    batch_size=FLAGS.batch_size)
    # The `use_fp16` flag may be undefined in this listing; treat that as False.
    if getattr(FLAGS, 'use_fp16', False):
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
def inputs(eval_data):
    """Construct input for CIFAR evaluation using the Reader ops.

    Args:
      eval_data: bool, indicating if one should use the train or eval data set.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    # Presumably the directory the DATA_URL binary tarball extracts to -- confirm.
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    # NOTE(review): assumes the reader-based cifar10_input API
    # inputs(eval_data, data_dir, batch_size); confirm against the local module.
    images, labels = cifar10_input.inputs(eval_data=eval_data,
                                          data_dir=data_dir,
                                          batch_size=FLAGS.batch_size)
    # The `use_fp16` flag may be undefined in this listing; treat that as False.
    if getattr(FLAGS, 'use_fp16', False):
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
# Core CNN architecture.
def inference(images):
    """Build the CIFAR-10 model.

    Layer stack: conv1 -> pool1 -> norm1 -> conv2 -> norm2 -> pool2 ->
    local3 -> local4 -> softmax_linear.

    Args:
      images: Images returned from distorted_inputs() or inputs().

    Returns:
      Logits: unscaled scores with NUM_CLASSES columns (the softmax is applied
      later, inside the loss).

    Raises:
      ValueError: if the non-batch dimensions of pool2 are not statically
        known, since the fully connected layer needs a fixed input width.
    """
    # We instantiate all variables using tf.get_variable() instead of
    # tf.Variable() in order to share variables across multiple GPU training runs.
    # If we only ran this model on a single GPU, we could simplify this function
    # by replacing all instances of tf.get_variable() with tf.Variable().
    #
    # conv1
    with tf.variable_scope('conv1') as scope:
        kernel = _variable_with_weight_decay('weights',
                                             shape=[5, 5, 3, 64],
                                             stddev=5e-2,
                                             wd=None)
        conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
        biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)

    # pool1
    pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool1')
    # norm1
    norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                      name='norm1')

    # conv2
    with tf.variable_scope('conv2') as scope:
        kernel = _variable_with_weight_decay('weights',
                                             shape=[5, 5, 64, 64],
                                             stddev=5e-2,
                                             wd=None)
        conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
        biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
        pre_activation = tf.nn.bias_add(conv, biases)
        conv2 = tf.nn.relu(pre_activation, name=scope.name)

    # norm2
    norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                      name='norm2')
    # pool2
    pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
                           strides=[1, 2, 2, 1], padding='SAME', name='pool2')

    # local3
    with tf.variable_scope('local3') as scope:
        # Move everything into depth so we can perform a single matrix multiply.
        # The flattened width is taken from pool2's static non-batch dims and
        # the batch dim is left as -1, so this also works when the batch size
        # is unknown at graph-construction time (e.g. a [None, 24, 24, 3] input).
        feature_dims = pool2.get_shape().as_list()[1:]
        if any(d is None for d in feature_dims):
            raise ValueError(
                'inference() needs statically known non-batch dimensions; '
                'got pool2 shape %s' % (pool2.get_shape(),))
        dim = 1
        for d in feature_dims:
            dim *= d
        reshape = tf.reshape(pool2, [-1, dim])
        weights = _variable_with_weight_decay('weights', shape=[dim, 384],
                                              stddev=0.04, wd=0.004)
        biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

    # local4
    with tf.variable_scope('local4') as scope:
        weights = _variable_with_weight_decay('weights', shape=[384, 192],
                                              stddev=0.04, wd=0.004)
        biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)

    # linear layer(WX + b),
    # We don't apply softmax here because
    # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
    # and performs the softmax internally for efficiency.
    with tf.variable_scope('softmax_linear') as scope:
        weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
                                              stddev=1/192.0, wd=None)
        biases = _variable_on_cpu('biases', [NUM_CLASSES],
                                  tf.constant_initializer(0.0))
        softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)

    return softmax_linear
def loss(logits, labels):
    """Compute the total training loss.

    Adds the batch-mean softmax cross entropy to the 'losses' collection and
    returns the sum of everything in that collection, which also holds the
    L2 weight-decay terms registered by _variable_with_weight_decay().

    Args:
      logits: Logits from inference().
      labels: Labels from distorted_inputs or inputs(). 1-D tensor
        of shape [batch_size]

    Returns:
      Loss tensor of type float.
    """
    sparse_labels = tf.cast(labels, tf.int64)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sparse_labels, logits=logits, name='cross_entropy_per_example')
    tf.add_to_collection('losses',
                         tf.reduce_mean(xent, name='cross_entropy'))

    # Total = mean cross entropy + every weight-decay term in the collection.
    collected = tf.get_collection('losses')
    return tf.add_n(collected, name='total_loss')
def _add_loss_summaries(total_loss):
    """Create the op that maintains moving averages of the losses.

    Every tensor in the 'losses' collection, plus `total_loss` itself, is
    tracked by an exponential moving average with decay 0.9. (The summary ops
    the name refers to are not part of this listing.)

    Args:
      total_loss: Total loss from loss().

    Returns:
      loss_averages_op: op for generating moving averages of losses.
    """
    averager = tf.train.ExponentialMovingAverage(0.9, name='avg')
    tracked = tf.get_collection('losses') + [total_loss]
    return averager.apply(tracked)
def train(total_loss, global_step):
    """Build the op that performs one CIFAR-10 training step.

    Uses plain gradient descent with a staircase exponentially decayed learning
    rate. After each weight update, exponential moving averages of all
    trainable variables are refreshed.

    Args:
      total_loss: Total loss from loss().
      global_step: Integer Variable counting the number of training steps
        processed.

    Returns:
      train_op: op for training.
    """
    # The learning rate is multiplied by LEARNING_RATE_DECAY_FACTOR once every
    # NUM_EPOCHS_PER_DECAY epochs' worth of batches.
    batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    steps_between_decays = int(batches_per_epoch * NUM_EPOCHS_PER_DECAY)
    learning_rate = tf.train.exponential_decay(
        INITIAL_LEARNING_RATE, global_step, steps_between_decays,
        LEARNING_RATE_DECAY_FACTOR, staircase=True)

    # Gradient computation is made to depend on the loss moving-average update.
    averaged_losses_op = _add_loss_summaries(total_loss)
    with tf.control_dependencies([averaged_losses_op]):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(total_loss)

    apply_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

    # Refresh the averaged copies of the weights only after they are updated.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    with tf.control_dependencies([apply_op]):
        return ema.apply(tf.trainable_variables())
def maybe_download_and_extract():
    """Download and extract the tarball from Alex's website.

    Downloads DATA_URL into FLAGS.data_dir unless the archive is already
    there, then unpacks it unless the 'cifar-10-batches-bin' directory already
    exists. Both steps are skipped when their output is present, so repeated
    calls are cheap.
    """
    dest_directory = FLAGS.data_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)

    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # urlretrieve reports total_size <= 0 when the size is unknown;
            # skip the percentage then instead of dividing by it.
            if total_size > 0:
                pct = 100.0 * count * block_size / total_size
                sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, pct))
                sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        statinfo = os.stat(filepath)
        sys.stdout.write('\nSuccessfully downloaded %s %d bytes.\n'
                         % (filename, statinfo.st_size))

    extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        with tarfile.open(filepath, 'r:gz') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters (Python 3.12+ and security backports)
                # reject members that would escape dest_directory.
                tar.extractall(dest_directory, filter='data')
            else:
                tar.extractall(dest_directory)