import cv2
import os
import glob
import torch
import time
import numpy as np
from torch.utils.data import Dataset,DataLoader
import concurrent.futures
class DatasetClass(Dataset):
    """Sliding-window clip dataset over directories of extracted video frames.

    Expected layout on disk::

        <base_dir>/<dataset_name>/<train_or_test>ing/frames/<video>/*.jpg

    Each sample is ``clip_len`` consecutive frames of a single video, stacked
    into a float32 array of shape [clip_len, H, W, C].
    """

    def __init__(self, base_dir, dataset_name, train_or_test, clip_len=6,
                 resize_height=256, resize_width=256, gray=False, scale=-1):
        """
        :param base_dir: root directory that contains the datasets.
        :param dataset_name: name of the dataset sub-directory.
        :param train_or_test: "train" or "test"; frames live under "<x>ing/frames".
        :param clip_len: number of consecutive frames per clip, >= 1.
        :param resize_height: height every frame is resized to.
        :param resize_width: width every frame is resized to.
        :param gray: convert frames to single-channel grayscale when True.
        :param scale: 0 -> scale pixels to [0, 1]; -1 -> scale to [-1, 1].
        :raises FileNotFoundError: if ``base_dir`` does not exist.
        :raises ValueError: on an invalid ``clip_len`` or ``scale``.
        """
        # Validate with explicit exceptions instead of `assert`, which is
        # silently stripped when Python runs with -O.
        if not os.path.exists(base_dir):
            raise FileNotFoundError("base_dir does not exist: {}".format(base_dir))
        if clip_len < 1:
            raise ValueError("clip_len must be >= 1, got {}".format(clip_len))
        if scale not in (0, -1):
            raise ValueError("scale must be 0 or -1, got {}".format(scale))
        self.dataset_dir = os.path.join(
            base_dir, "{}/{}ing/frames".format(dataset_name, train_or_test))
        self.clip_len = clip_len
        self.height = resize_height
        self.width = resize_width
        self.gray = bool(gray)
        self.scale = scale
        self.dataset_csv = self.get_dataset_csv()

    def __getitem__(self, clip_index):
        """
        :param clip_index: index of the clip in the precomputed clip list.
        :return: clip_data, float32 array of shape [clip_len, H, W, C].
        """
        frame_paths = self.dataset_csv[clip_index].split("&&")
        # np.stack adds the leading clip axis in a single call (equivalent to
        # expand_dims + concatenate on axis 0).
        return np.stack([self.np_load_frame(p) for p in frame_paths], axis=0)

    def __len__(self):
        """Number of clips in the dataset."""
        return len(self.dataset_csv)

    def get_dataset_csv(self):
        """Build the clip index from ``self.dataset_dir``.

        :return: list of strings, each the "&&"-joined frame paths of one clip
            of ``clip_len`` consecutive frames. Clips never span two videos;
            videos shorter than ``clip_len`` contribute no clips.
        """
        videos = sorted(glob.glob(os.path.join(self.dataset_dir, "*")))  # one dir per video
        clips_csv = []
        for video_dir in videos:
            frames = sorted(glob.glob(os.path.join(video_dir, "*.jpg")))
            # Sliding window of clip_len consecutive frames.
            for i in range(len(frames) - self.clip_len + 1):
                clips_csv.append("&&".join(frames[i:i + self.clip_len]))
        return clips_csv

    def np_load_frame(self, frame_path):
        """Load, resize and normalize a single frame.

        :param frame_path: path to a .jpg frame.
        :return: float32 image of shape [H, W, C] (C == 1 when ``gray``).
        :raises IOError: if the file cannot be decoded as an image.
        """
        image = cv2.imread(frame_path)
        if image is None:
            # cv2.imread returns None on a missing/corrupt file; fail loudly
            # instead of crashing with an opaque error inside cv2.resize.
            raise IOError("Failed to read image: {}".format(frame_path))
        image = cv2.resize(image, (self.width, self.height)).astype(np.float32)
        if self.gray:
            # cv2.imread loads images in BGR channel order, so the correct
            # conversion code is COLOR_BGR2GRAY. The original COLOR_RGB2GRAY
            # swapped the R/B luminance weights and produced wrong gray values.
            image = np.expand_dims(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), axis=-1)
        if self.scale == 0:    # normalize to [0, 1]
            image = image / 255.0
        if self.scale == -1:   # normalize to [-1, 1]
            image = (image / 127.5) - 1.0
        return image
if __name__ == "__main__":
    dataset_base_dir = "/media/wzg/Disk3T/Data"
    dataset = DatasetClass(dataset_base_dir, "***", "train", clip_len=10,
                           resize_width=256, resize_height=256, gray=False, scale=-1)

    # Usage 1: pull a single batch through an iterator.
    dataloader = DataLoader(dataset=dataset, batch_size=20, shuffle=True,
                            num_workers=4, drop_last=False)
    t0 = time.time()
    batch = next(iter(dataloader))
    elapsed = time.time() - t0
    print(elapsed, batch.shape)

    # Usage 2: iterate batch by batch, one full epoch at a time.
    for epoch in range(1):
        # Rebuild the DataLoader here each epoch (original author's intent:
        # make sure the data is reshuffled between epochs).
        dataloader = DataLoader(dataset=dataset, batch_size=50, shuffle=True,
                                num_workers=4, drop_last=False)
        batch_num = len(dataloader)  # number of batches per epoch
        t0 = time.time()
        for i, batch in enumerate(dataloader):
            print("epoch:{}, {}/{}".format(epoch, i, batch_num), batch.shape)
        print(time.time() - t0)  # wall-clock time to traverse one epoch