Using batch normalization in TensorFlow (the slim way)

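This post presents a TensorFlow implementation of a convolutional neural network (CNN) that uses batch normalization to improve training, and trains and validates it on the MNIST dataset. It shows in detail how to build the network and define the loss function and optimizer, among other key steps.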
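To make the roles of `is_training` and `decay` concrete, here is a minimal, framework-free sketch of what a batch-norm layer computes for a single feature. It is an illustration under stated assumptions, not slim's implementation: the class name `BatchNorm1Feature` is hypothetical, `epsilon = 0.001` and a learnable offset `beta` without a scale are assumed to loosely mirror slim's defaults, and details such as the exact variance estimator may differ from TensorFlow's kernels. The `decay = 0.95` value is the one the script passes in `normalizer_params`.

```python
import math


class BatchNorm1Feature:
    """Hypothetical single-feature batch norm, for illustration only.

    Assumptions: epsilon = 0.001, a learnable offset `beta` (kept fixed here)
    and no scale, loosely mirroring slim.batch_norm's defaults.
    """

    def __init__(self, decay=0.95, epsilon=0.001):
        self.decay = decay
        self.epsilon = epsilon
        self.beta = 0.0             # learnable offset; the optimizer would train it
        self.moving_mean = 0.0      # moving statistics start at 0 and 1
        self.moving_variance = 1.0

    def __call__(self, batch, is_training):
        if is_training:
            # Training: normalize with the current mini-batch statistics...
            mean = sum(batch) / len(batch)
            variance = sum((x - mean) ** 2 for x in batch) / len(batch)
            # ...and move the running averages toward them
            # (in TF 1.x these assignments are the ops in UPDATE_OPS).
            d = self.decay
            self.moving_mean = d * self.moving_mean + (1 - d) * mean
            self.moving_variance = d * self.moving_variance + (1 - d) * variance
        else:
            # Inference: reuse the accumulated statistics and update nothing.
            mean, variance = self.moving_mean, self.moving_variance
        inv_std = 1.0 / math.sqrt(variance + self.epsilon)
        return [(x - mean) * inv_std + self.beta for x in batch]


bn = BatchNorm1Feature(decay=0.95)
train_out = bn([1.0, 2.0, 3.0], is_training=True)   # updates moving statistics
eval_out = bn([1.0, 2.0, 3.0], is_training=False)   # uses moving statistics only
```

In training mode the layer normalizes with the current mini-batch and nudges the moving averages toward it; in inference mode it reuses the moving averages and changes nothing. In TensorFlow 1.x those moving-average updates are separate ops collected in `tf.GraphKeys.UPDATE_OPS`, which is why the script must make sure they run; `slim.learning.create_train_op` runs that collection by default.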

For details, see this blog post: https://blog.youkuaiyun.com/jiruiYang/article/details/77202674
It explains things well, and this GitHub code is worth borrowing from: https://github.com/soloice/mnist-bn
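The script that follows sets `activation_fn=tf.nn.crelu` in its `arg_scope`. CReLU concatenates `relu(x)` and `relu(-x)`, so each layer using it outputs twice as many features as its `num_outputs`: `conv1` yields 32 channels, `conv2` 64 and `fc1` 2048 (the `logits` layer overrides the activation with `None` and stays at 10). A minimal pure-Python sketch on a single feature vector, with hypothetical helper names:

```python
def relu(xs):
    """Element-wise max(0, x) on a list of floats."""
    return [max(0.0, x) for x in xs]


def crelu(xs):
    """Hypothetical pure-Python CReLU: concatenate relu(x) and relu(-x).

    Mirrors the idea of tf.nn.crelu along the last axis; the output has
    twice as many features as the input.
    """
    return relu(xs) + relu([-x for x in xs])


print(crelu([1.5, -2.0, 0.0]))  # [1.5, 0.0, 0.0, 0.0, 2.0, 0.0]
```

Unlike plain ReLU, CReLU keeps the information in negative pre-activations, at the cost of doubling the input width of the next layer.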

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import tensorflow.contrib.slim as slim  # TF 1.x only: tf.contrib was removed in TensorFlow 2.x
from tensorflow.python.ops import control_flow_ops

FLAGS = None


def model():
    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    keep_prob = tf.placeholder(tf.float32, [])
    y_ = tf.placeholder(tf.float32, [None, 10])
    is_training = tf.placeholder(tf.bool, [])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.crelu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params={'is_training': is_training, 'decay': 0.95}):
        conv1 = slim.conv2d(x_image, 16, [5, 5], scope='conv1')
        pool1 = slim.max_pool2d(conv1, [2, 2], scope='pool1')
        conv2 = slim.conv2d(pool1, 32, [5, 5], scope='conv2')
        pool2 = slim.max_pool2d(conv2, [2, 2], scope='pool2')
        flatten = slim.flatten(pool2)
        fc = slim.fully_connected(flatten, 1024, scope='fc1')
        drop = slim.dropout(fc, keep_prob=keep_prob)
        logits = slim.fully_connected(drop, 10, activation_fn=None, scope='logits')

    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

    step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)

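    # Note: create_train_op above already runs the ops in UPDATE_OPS (the BN
    # moving_mean/moving_variance updates) before each training step. The extra
    # dependency below also triggers those updates whenever cross_entropy (or the
    # merged summary) is evaluated with is_training=True.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
        updates = tf.group(*update_ops)
        cross_entropy = control_flow_ops.with_dependencies([updates], cross_entropy)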

    # Summaries: accuracy/loss scalars plus histograms of the BN variables
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('cross_entropy', cross_entropy)
    for v in tf.global_variables():  # tf.all_variables() is deprecated
        if v.name.startswith('conv1/Batch') or v.name.startswith('conv2/Batch') or \
                v.name.startswith('fc1/Batch') or v.name.startswith('logits/Batch'):
            print(v.name)
            # v.op.name drops the ':0' suffix, which is not a legal summary name
            tf.summary.histogram(v.op.name, v)
    merged_summary_op = tf.summary.merge_all()

    return {'x': x,
            'y_': y_,
            'keep_prob': keep_prob,
            'is_training': is_training,
            'train_step': train_step,
            'global_step': step,
            'accuracy': accuracy,
            'cross_entropy': cross_entropy,
            'summary': merged_summary_op}


def train():
    # clear checkpoint directory
    print('Clearing existing checkpoints and logs')
    for root, sub_folder, file_list in os.walk(FLAGS.checkpoint_dir):
        for f in file_list:
            os.remove(os.path.join(root, f))
    for root, sub_folder, file_list in os.walk(FLAGS.train_log_dir):
        for f in file_list:
            os.remove(os.path.join(root, f))

    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    net = model()
    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_log_dir, 'train'), sess.graph)
    valid_writer = tf.summary.FileWriter(os.path.join(FLAGS.train_log_dir, 'valid'), sess.graph)

    # Train
    batch_size = FLAGS.batch_size
    for i in range(10001):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        train_dict = {net['x']: batch_xs,
                      net['y_']: batch_ys,
                      net['keep_prob']: 0.5,
                      net['is_training']: True}
        step, _ = sess.run([net['global_step'], net['train_step']], feed_dict=train_dict)
        if step % 50 == 0:
            train_dict = {net['x']: batch_xs,
                          net['y_']: batch_ys,
                          net['keep_prob']: 1.0,
                          net['is_training']: True}
            entropy, acc, summary = sess.run([net['cross_entropy'], net['accuracy'], net['summary']],
                                             feed_dict=train_dict)
            train_writer.add_summary(summary, global_step=step)
            print('Train step {}: entropy {}: accuracy {}'.format(step, entropy, acc))

            # Note: the validation error is erratic in the beginning (Maybe 2~3k steps).
            # This does NOT imply the batch normalization is buggy.
            # On the contrary, it's BN's dynamics: moving_mean/variance are not estimated that well in the beginning.
            valid_xs, valid_ys = mnist.validation.next_batch(batch_size)  # held-out validation batch
            valid_dict = {net['x']: valid_xs,
                          net['y_']: valid_ys,
                          net['keep_prob']: 1.0,
                          net['is_training']: False}
            entropy, acc, summary = sess.run([net['cross_entropy'], net['accuracy'], net['summary']],
                                             feed_dict=valid_dict)
            valid_writer.add_summary(summary, global_step=step)
            print('***** Valid step {}: entropy {}: accuracy {} *****'.format(step, entropy, acc))
    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'mnist-conv-slim'))
    print('Finish training')

    # validation
    acc = 0.0
    batch_size = FLAGS.batch_size
    num_iter = 5000 // batch_size
    for i in range(num_iter):
        batch_xs, batch_ys = mnist.validation.next_batch(batch_size)
        test_dict = {net['x']: batch_xs,
                     net['y_']: batch_ys,
                     net['keep_prob']: 1.0,
                     net['is_training']: False}
        acc_ = sess.run(net['accuracy'], feed_dict=test_dict)
        acc += acc_
    print('Overall validation accuracy {}'.format(acc / num_iter))
    sess.close()


def test():
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    # Test trained model
    net = model()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if ckpt:
        saver.restore(sess, ckpt)
        print("restore from the checkpoint {0}".format(ckpt))

    acc = 0.0
    batch_size = FLAGS.batch_size
    num_iter = 10000 // batch_size
    for i in range(num_iter):
        batch_xs, batch_ys = mnist.test.next_batch(batch_size)
        feed_dict = {net['x']: batch_xs,
                     net['y_']: batch_ys,
                     net['keep_prob']: 1.0,
                     net['is_training']: False}
        acc_ = sess.run(net['accuracy'], feed_dict=feed_dict)
        acc += acc_
    print('Overall test accuracy {}'.format(acc / num_iter))
    sess.close()


def main(_):
    if FLAGS.phase == 'train':
        train()
    else:
        test()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='MNIST_data',
                        help='Directory for storing input data')
    parser.add_argument('--phase', type=str, default='train',
                        help='Training or test phase, should be one of {"train", "test"}')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Mini-batch size used for training, validation and test')
    parser.add_argument('--train_log_dir', type=str, default='log',
                        help='Directory for logs')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory for checkpoint file')
    FLAGS, unparsed = parser.parse_known_args()
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.mkdir(FLAGS.checkpoint_dir)
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
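The comment inside the training loop attributes the erratic early validation accuracy to poorly estimated `moving_mean`/`moving_variance`. One part of that is easy to quantify: a moving average forgets its initial value (0 for `moving_mean`, 1 for `moving_variance`) geometrically, leaving weight `decay**n` after `n` updates. The sketch below uses hypothetical helper names and pure Python to compute this.

```python
import math


def moving_average_after(n_steps, batch_stat, init=0.0, decay=0.95):
    """Hypothetical helper: apply n_steps moving-average updates toward a
    constant batch statistic, starting from `init`."""
    value = init
    for _ in range(n_steps):
        value = decay * value + (1 - decay) * batch_stat
    return value


def steps_until_bias_below(fraction, decay=0.95):
    """Hypothetical helper: smallest n with decay**n < fraction, i.e. the
    number of updates after which the initial value's weight drops below
    `fraction`."""
    return math.floor(math.log(fraction) / math.log(decay)) + 1


print(steps_until_bias_below(0.01, decay=0.95))   # 90
print(steps_until_bias_below(0.01, decay=0.999))  # 4603
```

At the script's `decay=0.95`, the initial value's weight falls below 1% after 90 updates; at slim's default `decay=0.999` it takes about 4,600. Our reading, not the original author's, is that with 0.95 the lag over the first few thousand steps comes mainly from the statistics themselves drifting as the weights change; a lower decay tracks that drift more closely, at the cost of noisier estimates.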