C3D源码解析

C3D源码解析

论文链接:http://vlg.cs.dartmouth.edu/c3d/c3d_video.pdf

代码链接:https://github.com/jfzhang95/pytorch-video-recognition

1.源码准备

git clone --recursive https://github.com/jfzhang95/pytorch-video-recognition.git

下载完成后得到C3D源码

2.源码结构

文件名称 功能
train.py 训练脚本
mypath.py 配置数据集和预训练模型的路径
dataest.py 数据读取和数据处理脚本
C3D_model.py C3D模型网络结构构建脚本
ucf101-caffe.path 预训练模型

接下来对一些重要文件, 将一一讲解,并且说清楚数据流的走向和函数调用关系。

3.源码分析(准备阶段)

3.1 数据准备

dataset.py的主要功能是对数据集进行读取,对数据集进行处理,获取对应的帧图片数据集和对应的动作、标签相对应的文档。

它首先定义了一个类VideoDataset,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。

(1)__init__函数

__init__函数的功能大致分为以下三个部分(1. 初始化类VideoDataset,并设置一些参数和参数默认值; 2. 生成视频对应的帧视频数据集; 3. 生成视频动作标签的txt文档–看着有点乱,有心的话可以自己封装一下),还有一些定义的函数,下面会逐步讲解。

第一部分:初始化类VideoDataset,并设置一些参数和参数默认值;

	def __init__(self, dataset='ucf101', split='train', clip_len=16, preprocess=False):	
     	self.root_dir, self.output_dir = Path.db_dir(dataset) #获取数据集的源路径和输出路径
        folder = os.path.join(self.output_dir, split) # 获取对应分组的的路径
        self.clip_len = clip_len # 16帧图片的意思
        self.split = split # 有三组 train val test

        # The following three parameters are chosen as described in the paper section 4.1
        # 图片的高和宽的变化过程(h*w-->128*171-->112*112)
        self.resize_height = 128
        self.resize_width = 171
        self.crop_size = 112

第二部分: 生成视频对应的帧视频数据集;

# check_integrity()判断是否存在Dataset的源路径,若不存在,则报错
        if not self.check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You need to download it from official website.')
        # check_preprocess()判断是否存在Dataset的输出路径,若不存在preprocess()则创建,并在其中生成对应的帧图片的数据集
        if (not self.check_preprocess()) or preprocess:
            print('Preprocessing of {} dataset, this will take long, but it will be done only once.'.format(dataset))
            self.preprocess() 

第三部分: 生成视频动作标签的txt文档;

# Obtain all the filenames of files inside all the class folders
        # Going through each class folder one at a time
    # fnames-->所有类别里的动作视频的集合; labels-->动作视频对应的标签
        self.fnames, labels = [], []
        for label in sorted(os.listdir(folder)):
            for fname in os.listdir(os.path.join(folder, label)):
                self.fnames.append(os.path.join(folder, label, fname))
                labels.append(label)

        assert len(labels) == len(self.fnames)
        print('Number of {} videos: {:d}'.format(split, len(self.fnames)))

        # Prepare a mapping between the label names (strings) and indices (ints)--> label和对应的数字标签
        self.label2index = {
   
   label: index for index, label in enumerate(sorted(set(labels)))}
        # Convert the list of label names into an array of label indices-->转化为数字标签
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)
		# 生成对应的动作和数字标签的txt文档
        if dataset == "ucf101":
            if not os.path.exists('dataloaders/ucf_labels.txt'):
                with open('dataloaders/ucf_labels.txt', 'w') as f:
                    for id, label in enumerate(sorted(self.label2index)):
                        f.writelines(str(id+1) + ' ' + label + '\n')

        elif dataset == 'hmdb51':
            if not os.path.exists('dataloaders/hmdb_labels.txt'):
                with open('dataloaders/hmdb_labels.txt', 'w') as f:
                    for id, label in enumerate(sorted(self.label2index)):
                        f.writelines(str(id+1) + ' ' + label + '\n')

接下来介绍一些VideoDataset类的重要函数:

(2)__len__函数:
    # 返回所有动作视频的总数
    def __len__(self):
        return len(self.fnames)
(3)__getitem__函数:
    def __getitem__(self, index):
        # Loading and preprocessing.
        buffer = self.load_frames(self.fnames[index]) #加载一个视频生成的帧图片[frames,h,w,3]-->[frames,128,171,3]
        buffer = self.crop(buffer, self.clip_len, self.crop_size) # [16,112,112,3]
        labels = np.array(self.label_array[index]) #  转化为数组

        if self.split == 'test':
            # Perform data augmentation
            buffer = self.randomflip(buffer) # 增强数据集
        buffer = self.normalize(buffer) # 归一化
        buffer = self.to_tensor(buffer) # [3,16,112,112]
        return torch.from_numpy(buffer), torch.from_numpy(labels) #以数组的形式返回
(4)check_intergrity函数:
    # check_integrity()判断是否存在Dataset的源路径,若不存在,则报错
    def check_integrity(self):
        if not os.path.exists(self.root_dir):
            return False
        else:
            return True
(5)chech_preprocess函数:
    # 检查输出路径是否存在,若不存在,则报错;检查输出路径的数据集图片格式是否正确,若不正确则报错
    def check_preprocess(self):
        # TODO: Check image size in output_dir
        if not os.path.exists(self.output_dir):
            return False
        elif not os.path.exists(os.path.join(self.output_dir, 'train')):
            return False

        for ii, video_class in enumerate(os.listdir(os.path.join(self.output_dir, 'train'))):
            for video in os.listdir(os.path.join(self.output_dir, 'train', video_class)):
                video_name = os.path.join(os.path.join(self.output_dir, 'train', video_class
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值