论文链接:https://arxiv.org/abs/1608.00859
代码链接:https://github.com/yjxiong/tsn-pytorch
论文笔记链接:https://blog.youkuaiyun.com/qq_39862223/article/details/108419664
论文实验过程:https://blog.youkuaiyun.com/qq_39862223/article/details/108461526
论文源码分析:https://blog.youkuaiyun.com/qq_39862223/article/details/108486039
1 源码准备
在指定文件夹下,输入命令:
git clone --recursive https://github.com/yjxiong/tsn-pytorch
下载完成后,得到tsn-pytorch源码。
2 源码结构
下表列出tsn-pytorch中比较重要的文件:
文件名称 | 功能 |
---|---|
main.py | 训练脚本 |
test_models.py | 测试脚本 |
opts.py | 参数配置脚本 |
dataset.py | 数据读取脚本 |
models.py | 网络结构构建脚本 |
transforms.py | 数据预处理相关的脚本 |
tf_model_zoo文件夹 | 导入模型结构的脚本 |
接下来对一些重要文件,将一一讲解,并且说清数据流的走向和函数调用关系。
3. 源码分析(准备阶段)
3.1 数据准备
dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。
它首先定义了一个类TSNDataSet,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(注:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。
(1)_init_函数
_init_函数的功能在于初始化TSNDataSet,并设置一些参数和参数默认值。
def __init__(self, root_path, list_file,
num_segments=3, new_length=1, modality='RGB',
image_tmpl='img_{:05d}.jpg', transform=None,
force_grayscale=False, random_shift=True, test_mode=False):
self.root_path = root_path
self.list_file = list_file
self.num_segments = num_segments
self.new_length = new_length
self.modality = modality
self.image_tmpl = image_tmpl
self.transform = transform
self.random_shift = random_shift
self.test_mode = test_mode
if self.modality == 'RGBDiff':
self.new_length += 1# Diff needs one more image to calculate diff
self._parse_list()
TSNDataSet类的初始化方法_init_需要如下参数:
- root_path : 项目的根目录地址,如果其他文件地址使用绝对地址,则可以写成" "
- list_file : 训练或测试的列表文件(.txt文件)地址
- num_segments : 视频分割的段数
- new_length : 根据输入数据集类型的不同,new_length取不同的值
- modality : 输入数据集类型(RGB、光流、RGB差异)
- image_tmpl : 图片的名称
- transform : 数据集是否进行变换操作
- random_shift : 稀疏采样时是否增加一个随机数
- test_mode : 是否是测试时的数据集输入
(2)_parse_list函数
_parse_list函数功能在于读取list文件,储存在video_list中
def _parse_list(self):
self.video_list = [VideoRecord(x.strip().split(' ')) for x in open(self.list_file)]
self.video_list是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。
(3)_sample_indices函数
_sample_indices函数功能在于实现TSN的稀疏采样,返回的是稀疏采样的帧数列表
def _sample_indices(self, record):
average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
if average_duration > 0:
offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
elif record.num_frames > self.num_segments:
offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
else:
offsets = np.zeros((self.num_segments,))
return offsets + 1
假设一个视频共有125帧,num_segments=3,输入模态为RGB,稀疏采样的步骤如下:
-
将视频分成num_segments=3段。根据代码,record.num_frames=150,self.new_length=1,求出平均每段的帧数为50帧。
-
定义一个list类型的变量offset,首先取第一个片段里的帧,假设随机数randint(average_duration,size=self.num_segments)=10,第一个片段时range(self.num_segments)=0,计算可得第一个片段中取到的帧编号为10
-
同理可获得其他片段中取到帧的编号,假设第二帧时,随机数取12,第三帧时,随机数取15,计算可得第二个、第三个片段中取到的帧编号,分别为62,115
-
经过上述计算,列表offset=[10,62,115],当返回时,返回的为offset+1,即真正取到的帧数为[11,63,116]
(4)_getitem_函数
该函数会在TSNDataSet初始化之后执行,功能在于调用执行稀疏采样的函数_sample_indices,并且调用get方法,得到TSNDataSet的返回
def __getitem__(self, index):
record = self.video_list[index]
if not self.test_mode:
segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
else:
segment_indices = self._get_test_indices(record)
return self.get(record, segment_indices)
record变量读取的是video_list的第index个数据,包含该视频所在的文件地址、视频包含的帧数和视频所属的分类。如果该TSNDataSet不是为测试部分运行的,则对_sample_indices(record)或_get_val_indices(record)运行,判断条件在于它是否为训练数据集,如果是,则执行前者,否则,执行后者。将稀疏采样获得的帧列表保存于segment_indices中,之后调用get()方法,作为其中的参数。
(5)get函数
get方法的功能在于读取提取的帧图片,并且对帧图片进行变形操作(角裁剪、中心提取等)
def get(self, record, indices):
images = list()
for seg_ind