TSN源码分析

本文详细分析了TSN源码,包括数据准备、模型设计及训练部分。介绍了`dataset.py`中数据集的读取、稀疏采样方法,以及`models.py`中TSN模型的构建过程,如基础模型的选择和调整。训练部分讲解了`AverageMeter`类、`accuracy`函数、`train`函数和`main`函数,阐述了模型训练的完整流程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在这里插入图片描述
论文链接:https://arxiv.org/abs/1608.00859
代码链接:https://github.com/yjxiong/tsn-pytorch
论文笔记链接:https://blog.youkuaiyun.com/qq_39862223/article/details/108419664
论文实验过程:https://blog.youkuaiyun.com/qq_39862223/article/details/108461526
论文源码分析:https://blog.youkuaiyun.com/qq_39862223/article/details/108486039

1 源码准备

在指定文件夹下,输入命令:

git clone --recursive https://github.com/yjxiong/tsn-pytorch 

下载完成后,得到tsn-pytorch源码。

2 源码结构

下表列出tsn-pytorch中比较重要的文件:

文件名称 功能
main.py 训练脚本
test_models.py 测试脚本
opts.py 参数配置脚本
dataset.py 数据读取脚本
models.py 网络结构构建脚本
transforms.py 数据预处理相关的脚本
tf_model_zoo文件夹 导入模型结构的脚本

接下来对一些重要文件,将一一讲解,并且说清数据流的走向和函数调用关系。

3. 源码分析(准备阶段)

3.1 数据准备

dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。

它首先定义了一个类TSNDataSet,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。

(1)_init_函数

_init_函数的功能在于初始化TSNDataSet,并设置一些参数和参数默认值。

def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 force_grayscale=False, random_shift=True, test_mode=False):
        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode

        if self.modality == 'RGBDiff':
            self.new_length += 1# Diff needs one more image to calculate diff

        self._parse_list()

TSNDataSet类的初始化方法_init_需要如下参数:

  • root_path : 项目的根目录地址,如果其他文件地址使用绝对地址,则可以写成" "
  • list_file : 训练或测试的列表文件(.txt文件)地址
  • num_segments : 视频分割的段数
  • new_length : 根据输入数据集类型的不同,new_length取不同的值
  • modality : 输入数据集类型(RGB、光流、RGB差异)
  • image_tmpl : 图片的名称
  • transform : 数据集是否进行变换操作
  • random_shift : 稀疏采样时是否增加一个随机数
  • test_mode : 是否是测试时的数据集输入
(2)_parse_list函数

_parse_list函数功能在于读取list文件,储存在video_list中

    def _parse_list(self):
        self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)]

self.video_list是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。

(3)_sample_indices函数

_sample_indices函数功能在于实现TSN的稀疏采样,返回的是稀疏采样的帧数列表

    def _sample_indices(self, record):
        average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
        elif record.num_frames > self.num_segments:
            offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))
        return offsets + 1

假设一个视频共有125帧,num_segments=3,输入模态为RGB,稀疏采样的步骤如下:

  • 将视频分成num_segments=3段。根据代码,record.num_frames=150,self.new_length=1,求出平均每段的帧数为50帧。
    在这里插入图片描述

  • 定义一个list类型的变量offset,首先取第一个片段里的帧,假设随机数randint(average_duration,size=self.num_segments)=10,第一个片段时range(self.num_segments)=0,计算可得第一个片段中取到的帧编号为10
    在这里插入图片描述

  • 同理可获得其他片段中取到帧的编号,假设第二帧时,随机数取12,第三帧时,随机数取15,计算可得第二个、第三个片段中取到的帧编号,分别为62,115
    在这里插入图片描述

  • 经过上述计算,列表offset=[10,62,115],当返回时,返回的为offset+1,即真正取到的帧数为[11,63,116]

(4)_getitem_函数

该函数会在TSNDataSet初始化之后执行,功能在于调用执行稀疏采样的函数_sample_indices,并且调用get方法,得到TSNDataSet的返回

def __getitem__(self, index):
        record = self.video_list[index]
        if not self.test_mode:
            segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        else:
            segment_indices = self._get_test_indices(record)

        return self.get(record, segment_indices)

record变量读取的是video_list的第index个数据,包含该视频所在的文件地址、视频包含的帧数和视频所属的分类。如果该TSNDataSet不是为测试部分运行的,则对_sample_indices(record)或_get_val_indices(record)运行,判断条件在于它是否为训练数据集,如果是,则执行前者,否则,执行后者。将稀疏采样获得的帧列表保存于segment_indices中,之后调用get()方法,作为其中的参数。

(5)get函数

get方法的功能在于读取提取的帧图片,并且对帧图片进行变形操作(角裁剪、中心提取等)

def get(self, record, indices):

        images = list()
        for seg_ind 
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值