Pytorch Dataset类和Dataloader实践

此篇博客介绍了一个用于处理视频数据的Python类,DatasetClass,它读取指定目录下的视频剪辑,进行裁剪、缩放和可能的灰度转换,并提供两种数据加载方式。核心内容涉及文件操作、数据增强和PyTorch Dataset实现。
import cv2
import os
import glob
import torch
import time
import numpy as np
from torch.utils.data import Dataset,DataLoader
import concurrent.futures

class DatasetClass(Dataset):
    def __init__(self,base_dir,dataset_name,train_or_test,clip_len=6,resize_height=256, resize_width=256, gray=False, scale=-1):
        assert os.path.exists(base_dir)
        assert clip_len >= 1
        assert gray in [True, False]
        assert scale in [0,-1]
        dataset_dir = os.path.join(base_dir,"{}/{}ing/frames".format(dataset_name,train_or_test))
        self.dataset_dir = dataset_dir
        self.clip_len = clip_len
        self.height = resize_height
        self.width = resize_width
        self.gray = gray
        self.scale = scale
        self.dataset_csv = self.get_dataset_csv()
        
    def __getitem__(self, clip_index):
        """
        :param clip_index:
        :return: clip_data,shape=[clip_len,H,W,C]
        """
        clip_frames = self.dataset_csv[clip_index].split("&&")
        clip_data = []
        for frame_path in clip_frames:
            clip_data.append(np.expand_dims(self.np_load_frame(frame_path),axis=0))
        clip_data = np.concatenate(clip_data, axis=0)
        return clip_data
    
    def __len__(self):
        return len(self.dataset_csv)

    def get_dataset_csv(self):
        """
        :param self.dataset_dir
        :param self.clip_len
        :return: 一个文件列表,每个元素是:一个clip的视频帧路径的串联字符串,串联符号是“&&”
        """
        videos = sorted(glob.glob(os.path.join(self.dataset_dir, "*")))  # 查看有多少视频
        videos_frames = [sorted(glob.glob(os.path.join(v, "*.jpg"))) for v in videos]  # 查看每个视频有多少视频帧
        clips_csv = []
        for video in videos_frames:
            for i in range(len(video) - self.clip_len + 1):
                clips_csv.append("&&".join(video[i:i + self.clip_len]))
        return clips_csv

    def np_load_frame(self,frame_path):
        """
        :param frame_path:
        :return: image.shape = [H,W,C]
        """
        image = cv2.imread(frame_path)
        image = cv2.resize(image, (self.width, self.height)).astype(dtype=np.float32)
        if self.gray: # 转为灰度图
            image = np.expand_dims(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), axis=-1)
        if self.scale == 0: # 归一化到[0,1]
            image = image/255.0
        if self.scale == -1: # 归一化到[-1,1]
            image = (image/127.5) - 1.0
        return image

        
        
if __name__ =="__main__":
    dataset_base_dir = "/media/wzg/Disk3T/Data"
    dataset = DatasetClass(dataset_base_dir,"***","train", clip_len=10, resize_width=256, resize_height=256, gray=False, scale=-1)
    
    """使用方式1,迭代器"""
    dataloader = DataLoader(dataset=dataset, batch_size=20, shuffle=True, num_workers=4, drop_last=False)
    time_start = time.time()
    data = next(iter(dataloader))
    time_end = time.time()
    time_cost = time_end - time_start
    print(time_cost,data.shape)
    
    """使用方式2 逐个epoch提取"""
    for epoch in range(1):
        # 把dataloader放着这里有助于重新打乱数据
        dataloader = DataLoader(dataset=dataset, batch_size=50, shuffle=True, num_workers=4, drop_last=False)
        batch_num = len(dataloader)  # 一个epoch的batch数
        time_start = time.time()
        for i,data in enumerate(dataloader):
            print("epoch:{}, {}/{}".format(epoch,i,batch_num),data.shape)
        time_end = time.time()
        print(time_end - time_start)# 统计遍历完一个epoch的时间

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值