使用TensorFlow做CIFAR-10 Dataset的图片分类(全代码及解析)(适合初学者)

本文详细介绍了使用TensorFlow进行CIFAR-10图像分类的全过程,包括卷积神经网络的核心部分、正则化的概念及其在避免过拟合中的作用。代码涵盖cifar10_train.py(单GPU版本)和cifar10.py,讨论了模型训练、数据处理、损失函数以及L1和L2正则化的应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

实验之前我们必须声明一点,普通的全连接 BP 网络跑一下 CIFAR-10 也是 有比较好的结果的,只是 CNN会有更好的收敛速度和更高一些的精确度。 这里说的 CIFAR-10 是由 Geoffrey Hinton 和他的两个学生 Alex Krizhevsky、 Ilya Sutskever 所收集的一个用于普适物体识别的数据集,也叫作 CIFAR-10 Dataset。 CIFAR 是加拿大政府牵头投资的一个先进科学项目研究所,全称是 Cooperative Institute for Arctic Research, 项目主页位于: https://www.cs.toronto.edu/~kriz/cifar.html。 虽然看起来样子非常简阻,但 是实验用的基本信息一应俱全。 这个项目中包含了 60000 张 32 × 32 像素的彩色图片,拥有 10 个不同类别的标签。 其中 50000 张是训练集,还有 10000 张是测试集当作实验玩具来说应该是绰绰有余了。包括前面我们提到的 MNIST 数据集在内,它们都是由政府或者大的非盈利组织提供出来供初 学者学习或交流所用的,毕竟带有高质量标签的样本是在深度学习中是成本最高的东西了。

TensorFlow 官方同样提供了 GitHub 地址供大家下载 CNN 做 CIFARlO 实验的代码,
位置在: https://github.com/tensorflow/models
文件所在目录 https://github.com/tensorflow/models/tutorials/image/cifar10

卷积神经网络最重要的核心部分就是卷积核。 在很多复杂的应用中,对于同一个输入的向量由不同尺寸的卷积核扫描,产生不同的特征描述 Feature Map 输入到后端,也可能在 不同的层用不同尺寸的卷积核去提取特征。卷积神经网络的特征就是有卷积层,进而带来的好处就是收敛速度比较快并且泛化能力会显得比较好。 卷积核的优良特性使得它在很多网络中都有使用,它可能会由于模型上 的需要仅仅出现在一个网络中的某几层的位置,也可能会在一个模型的多层中出现,总而 言之应用起来还是非常灵活的。 而是不是用了卷积核的网络都会被称作卷积神经网络这个 倒是未必至少很多网络由于其中应用了许多其他的结构而使得网络体现出来很多更为独特的特性的时候,命名时会更倾向使用标新立异的方式,后面我们看看深度残差网络就知道了。


关于正则化

正则化这一过程就是帮助我们找到更为简洁的描述方式的量化过程


在机器学习中,我们是通过大量样本放入模型中训练得到待定系数的,而不论是哪种 模型,其实我们都希望这种模型在精确的前提下尽可能简洁。 请注意,这里说的精确可不 是说在测试集上精确就够了,是指其泛化能力要好,也就是说在验证集以及其他测试集上 同样要有好的表现。


对于观察到的 各种认知对象来说,描述共性的东西越抽象、越简洁,其泛化性也就越好;相反,越是精 确描述个体的东西,通常“个性化”的特点就非常明显,越具体、越复杂,泛化性也就越差。


例如,我们在描述一个事物( object),说这个东西是“方的”,那么通常是指这个物体 的投影外形是有四条边组成,其中两两平行,并且两两垂直。 “方的”这个词的描述就非常 简洁,而描述的内容则是忽略掉大小、材质、重量、颜色等诸多性状的。 而如果叙述“正 方形手帕”这样一个词汇,描述的内容就变多了,你也可以认为参数变多了,而这个时候 其实就有了约束性,从而降低了泛化性。 因为一旦你说“正方形手帕”,那么这个物体首先 材质应该是类似于棉布、纱绸、麻丝一类的织物,其他材质显然和它对不上号,正方形则 表示性状的约束更为严格,起码给人的感觉四条边的长短不会有明显的不同。 这些就是你 加入更多描述之后产生的限制和泛化性缩小的过程。 如果你再叙述一个“昨天从淘宝买的 由苏州发货的白色双面绣正方形蕾丝手帕”,这个描述更为具体,当然参数也就更多,但是 泛化性明显是刚才这几个语汇中最低的一个,你能用它来指代的事物就最少。
而正则化这一过程就是帮助我们找到更为简洁的描述方式的量化过程


这里写图片描述

加上这个部分能够在一定程度上避免过拟合的原理,现在就来具体讲。 从学术上来讲,前半部分的损失函数叫作“经验风险”,后半部分的损失函数(也就是加人的正则 化项的部分)叫作“结构风险” 。 所谓“经验风险”就是指那些由于拟合结果和样本标签之 间的残差总和所产生的这种经验性差距所带来的风险一一毕竟差距越大模型拟合失效的可 能性也就越大,这当然是风险,欠拟合的风险;“结构风险”就是我们刚刚提到的那种概念 了,我们希望这种描述能够简洁来保证其泛化性的良好,所以加入一个因子L2正则项。


L1因子的含义就是把整个模型中所有的权重 w 的绝对值加起来除以样本数量。 其中 A 不是我们说的学习率(虽然有的资料上会用 λ 做学习率的符号表示),而是一个权重一一也 可以称为正则化系数或惩罚系数,表示对这个部分有多‘重视” 。 如果我们很重视结构风险, 或者说很不希望结构风险太大,那我们就加大 λ,迫使整个损失函数向着权值 w 减小的方 向快速移动。 换句话说, w 的值越多、越大,整个因子的值就越大,也就是越不“简洁” 。


而在 L2 正则项中求导后正好可以消掉一分母中的 2,计算起来要方便 一些,这也是在构造这种因子的时候特别设计的 Trick


cifar10_train.py(单GPU版本)

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A binary to train CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.

Speed: With batch_size 128.

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 100000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()

127 行,启动 TensorFlow。

def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()

118 ~ 123 行,启动 TensorFlow 后首先调用 main 函数,下载 cifar10 dataset,创建目录。

  with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
    # GPU and resulting in a slow down.
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()

61行,使用默认图。
62 行, 全局步数变量。
68 行,获取训练数据和其对应的标签。

    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

72 ~ 79 行,创建网络 Op, loss Op,训练 Op

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:

107 ~ 113 行,这里使用 MonitoredTrainingSession 可以设置钩子函数在开始训练之前,

每次运行之前,以及运行过程中设置回调函数处理变量,输出信息。

回调函数和钩子函数的区别
根本上,他们都是为了捕获消息而生的,但是钩子函数在捕获消息的第一时间就会执行,而回调函数是在整个捕获过程结束时,最后一个被执行的。

回调函数其实就是调用者把回调函数的函数指针传递给调用函数,当调用函数执行完毕时,通过函数指针来调用回调函数

      while not mon_sess.should_stop():
        mon_sess.run(train_op)

113 ~ 114 行,开始训练,传人 train Op 直到 FLAGS.max_steps 停止
上方代码有设置: “.train.StopAtStepHook(last_step=FLAGS.max_steps)

      def begin(self):
        self._step = -1
        self._start_time = time.time()

      def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

85 行,开始训练之前设置当前步数。
88 行,每次运行之前更新当前步数。
90 行,运行一个 step,传进 loss Op。 98 行, loss Op 的返回结果。

        if self._step % FLAGS.log_frequency == 0:
          current_time = time.time()
          duration = current_time - self._start_time
          self._start_time = current_time

          loss_value = run_values.results
          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
          sec_per_batch = float(duration / FLAGS.log_frequency)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print (format_str % (datetime.now(), self._step, loss_value,
                               examples_per_sec, sec_per_batch))

93 ~ 105 行,运行时每 FLAGS.log_frequency(数字) 个 step 输出信息。


cifar10.py

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License a
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值