用自己的数据集训练TimeSformer并转ONNX用c++推理

用自己的数据集训练TimeSformer并转ONNX用c++推理

下载安装TimeSformer

TimeSformer开源地址

按照官方教程安装好环境。
如果报下面这个错误,是因为新版的pytorch已经不支持那种写法了,需要修改一下。

ImportError: cannot import name '_LinearWithBias' from 'torch.nn.modules.linear'

可以参考这个人的fork修改

创建分类文件夹

我这里有61个动作分类,每个分类创建一个文件夹
在这里插入图片描述
将视频文件分割成 每个视频大概10s左右;
然后将视频文件按照分类放到每个文件夹里。

创建数据集

写一个脚本分割数据集,并生成标签文件

import os
import csv
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_split

out_dir = "/home/disk/liangbaikai/TimeSformer/mydata/mydatasets"  # 输出路径
video_path = "/home/disk/liangbaikai/TimeSformer/mydata/myvideos" # 数据集路径
file_name = ".csv"
name_list = ["train","test","val"]

if not os.path.exists(out_dir):
    os.mkdir(out_dir)
if not os.path.exists(os.path.join(out_dir, 'train')):
    os.mkdir(os.path.join(out_dir, 'train'))
if not os.path.exists(os.path.join(out_dir, 'val')):
    os.mkdir(os.path.join(out_dir, 'val'))
if not os.path.exists(os.path.join(out_dir, 'test')):
    os.mkdir(os.path.join(out_dir, 'test'))

for file in os.listdir(video_path):
        file_path = os.path.join(video_path, file)
        video_files = [name for name in os.listdir(file_path)]
        #将20%的数据分配给test
        train_and_valid, test = train_test_split(video_files, test_size=0.2, random_state=42)
        #将80%的数据再分配20%出来给val,剩下的给train
        train, val = train_test_split(train_and_valid, test_size=0.2, random_state=42)
        train_dir = os.path.join(out_dir, 'train', file)
        val_dir = os.path.join(out_dir, 'val', file)
        test_dir = os.path.join(out_dir, 'test', file)
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)
        if not os.path.exists(val_dir):
            os.mkdir(val_dir)
        if not os.path.exists(test_dir):
            os.mkdir(test_dir)
        for video in tqdm(train):
           shutil.copy(os.path.join(video_path,file,video),os.path.join(train_dir,video))
        for video in tqdm(test):
            shutil.copy(os.path.join(video_path,file,video),os.path.join(test_dir,video))
        for video in tqdm(val):
            shutil.copy(os.path.join(video_path,file,video),os.path.join(val_dir,video))

#输出路径下创建csv文件夹,并在文件夹下创建train.csv val.csv test.csv
csv_path = os.path.join(out_dir,"csv")
if not os.path.exists(csv_path):
    os.mkdir(csv_path)
    for name in name_list:
        with open(os.path.join(csv_path,name+file_name),'wb') as f:
            print("创建"+os.path.join(csv_path,name+file_name))



for ii in os.listdir(csv_path):
    if ii.split(".")[0] in name_list:
        path1 = os.path.join(csv_path,ii)
        with open(path1, 'w', newline='') as f:
            for dd in os.listdir(out_dir):
                if dd==ii.split(".")[0]:
                    for zz in os.listdir(os.path.join(out_dir,dd)):
                        for mm in os.listdir(os.path.join(out_dir,dd,zz)):
                            writer = csv.writer(f)
                            writer.writerow([os.path.join(out_dir,dd,zz,mm),zz])

## 创建类别label标号文件
labels= []
for label in sorted(os.listdir(video_path)):
    labels.append(label)
label2index = {
   label: index for index, label in enumerate(sorted(set(labels)))}
label_file = os.path.join(out_dir, str(len(os.listdir(video_path))) + 'class_labels.txt')
with open(label_file, 'w') as f:
    for id, label in enumerate(sorted(label2index)):
        f.writelines(str(id) + ' ' + label +'\n')

#替换csv文件中类别名为数字
csv_file = os.path.join(out_dir,"csv")
def txt_read(files):
    txt_dict = {
   }
    fopen = open(files)
    for line in fopen.readlines():
        line = str(line).replace('\n','')
        txt_dict[line.split(' ',1)[1]] = line.split(' ',1)[0]      
    fopen.close()
    return txt_dict
txt_dict = txt_read(label_file)
print(txt_dict)

for ii in os.listdir(csv_file):
    path1 = os.path.join(csv_file,ii)
    r = csv.reader(open(path1))
    lines = [l for l in r]
    for i in range(
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值