C3D源码解析
论文链接:http://vlg.cs.dartmouth.edu/c3d/c3d_video.pdf
代码链接:https://github.com/jfzhang95/pytorch-video-recognition
1.源码准备
git clone --recursive https://github.com/jfzhang95/pytorch-video-recognition.git
下载完成后得到C3D源码
2.源码结构
| 文件名称 | 功能 |
|---|---|
| train.py | 训练脚本 |
| mypath.py | 配置数据集和预训练模型的路径 |
| dataest.py | 数据读取和数据处理脚本 |
| C3D_model.py | C3D模型网络结构构建脚本 |
| ucf101-caffe.path | 预训练模型 |
接下来对一些重要文件, 将一一讲解,并且说清楚数据流的走向和函数调用关系。
3.源码分析(准备阶段)
3.1 数据准备
dataset.py的主要功能是对数据集进行读取,对数据集进行处理,获取对应的帧图片数据集和对应的动作、标签相对应的文档。
它首先定义了一个类VideoDataset,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(注:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。
(1)__init__函数
__init__函数的功能大致分为以下三个部分(1. 初始化类VideoDataset,并设置一些参数和参数默认值; 2. 生成视频对应的帧视频数据集; 3. 生成视频动作标签的txt文档–看着有点乱,有心的话可以自己封装一下),还有一些定义的函数,下面会逐步讲解。
第一部分:初始化类VideoDataset,并设置一些参数和参数默认值;
def __init__(self, dataset='ucf101', split='train', clip_len=16, preprocess=False):
self.root_dir, self.output_dir = Path.db_dir(dataset) #获取数据集的源路径和输出路径
folder = os.path.join(self.output_dir, split) # 获取对应分组的的路径
self.clip_len = clip_len # 16帧图片的意思
self.split = split # 有三组 train val test
# The following three parameters are chosen as described in the paper section 4.1
# 图片的高和宽的变化过程(h*w-->128*171-->112*112)
self.resize_height = 128
self.resize_width = 171
self.crop_size = 112
第二部分: 生成视频对应的帧视频数据集;
# check_integrity()判断是否存在Dataset的源路径,若不存在,则报错
if not self.check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You need to download it from official website.')
# check_preprocess()判断是否存在Dataset的输出路径,若不存在preprocess()则创建,并在其中生成对应的帧图片的数据集
if (not self.check_preprocess()) or preprocess:
print('Preprocessing of {} dataset, this will take long, but it will be done only once.'.format(dataset))
self.preprocess()
第三部分: 生成视频动作标签的txt文档;
# Obtain all the filenames of files inside all the class folders
# Going through each class folder one at a time
# fnames-->所有类别里的动作视频的集合; labels-->动作视频对应的标签
self.fnames, labels = [], []
for label in sorted(os.listdir(folder)):
for fname in os.listdir(os.path.join(folder, label)):
self.fnames.append(os.path.join(folder, label, fname))
labels.append(label)
assert len(labels) == len(self.fnames)
print('Number of {} videos: {:d}'.format(split, len(self.fnames)))
# Prepare a mapping between the label names (strings) and indices (ints)--> label和对应的数字标签
self.label2index = {
label: index for index, label in enumerate(sorted(set(labels)))}
# Convert the list of label names into an array of label indices-->转化为数字标签
self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)
# 生成对应的动作和数字标签的txt文档
if dataset == "ucf101":
if not os.path.exists('dataloaders/ucf_labels.txt'):
with open('dataloaders/ucf_labels.txt', 'w') as f:
for id, label in enumerate(sorted(self.label2index)):
f.writelines(str(id+1) + ' ' + label + '\n')
elif dataset == 'hmdb51':
if not os.path.exists('dataloaders/hmdb_labels.txt'):
with open('dataloaders/hmdb_labels.txt', 'w') as f:
for id, label in enumerate(sorted(self.label2index)):
f.writelines(str(id+1) + ' ' + label + '\n')
接下来介绍一些VideoDataset类的重要函数:
(2)__len__函数:
# 返回所有动作视频的总数
def __len__(self):
return len(self.fnames)
(3)__getitem__函数:
def __getitem__(self, index):
# Loading and preprocessing.
buffer = self.load_frames(self.fnames[index]) #加载一个视频生成的帧图片[frames,h,w,3]-->[frames,128,171,3]
buffer = self.crop(buffer, self.clip_len, self.crop_size) # [16,112,112,3]
labels = np.array(self.label_array[index]) # 转化为数组
if self.split == 'test':
# Perform data augmentation
buffer = self.randomflip(buffer) # 增强数据集
buffer = self.normalize(buffer) # 归一化
buffer = self.to_tensor(buffer) # [3,16,112,112]
return torch.from_numpy(buffer), torch.from_numpy(labels) #以数组的形式返回
(4)check_intergrity函数:
# check_integrity()判断是否存在Dataset的源路径,若不存在,则报错
def check_integrity(self):
if not os.path.exists(self.root_dir):
return False
else:
return True
(5)chech_preprocess函数:
# 检查输出路径是否存在,若不存在,则报错;检查输出路径的数据集图片格式是否正确,若不正确则报错
def check_preprocess(self):
# TODO: Check image size in output_dir
if not os.path.exists(self.output_dir):
return False
elif not os.path.exists(os.path.join(self.output_dir, 'train')):
return False
for ii, video_class in enumerate(os.listdir(os.path.join(self.output_dir, 'train'))):
for video in os.listdir(os.path.join(self.output_dir, 'train', video_class)):
video_name = os.path.join(os.path.join(self.output_dir, 'train', video_class

最低0.47元/天 解锁文章
3194

被折叠的 条评论
为什么被折叠?



