帮我写一个脚本，将下面的数据集类改造成视频处理脚本：使用其中的取帧函数（`load_frames_using_imageio`），对 /common-data/kunlin/code/ReCamMaster-main/test_video/videos/ 目录下的每个视频都截取最大帧数（`max_num_frames`，默认 81 帧），然后按原文件名重新保存为 MP4 文件：
class TextVideoCameraDataset(torch.utils.data.Dataset):
    """Text/video dataset that attaches a target camera trajectory to every clip.

    Each item is a dict with the caption (``"text"``), a normalized video tensor
    (``"video"``, laid out C, T, H, W), the source path (``"path"``), the flattened
    relative camera poses (``"camera"``) and, when ``is_i2v`` is True, the first
    frame (``"first_frame"``).

    Args:
        base_path: Directory whose ``videos`` sub-directory holds the clips.
        metadata_path: CSV file with ``file_name`` and ``text`` columns.
        args: Namespace-like object; only ``args.cam_type`` is read.
        max_num_frames: Videos with fewer frames are rejected; also bounds the
            random start frame.
        frame_interval: Stride between sampled frames.
        num_frames: Number of frames sampled per clip.
        height: Output frame height.
        width: Output frame width.
        is_i2v: If True, also return the first frame (image-to-video mode).
        camera_path: JSON file with the target camera extrinsics. ``None`` keeps
            the previously hard-coded default path.
    """

    DEFAULT_CAMERA_PATH = "/common-data/kunlin/code/ReCamMaster-main/example_test_data/cameras/camera_extrinsics.json"
    # Every clip must have exactly this many frames; with stride-4 camera
    # sampling this yields 21 poses.
    EXPECTED_NUM_FRAMES = 81

    def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1,
                 num_frames=81, height=480, width=832, is_i2v=False, camera_path=None):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v
        self.args = args
        self.cam_type = self.args.cam_type
        self.camera_path = camera_path if camera_path is not None else self.DEFAULT_CAMERA_PATH
        self._cam_data = None  # parsed camera JSON, loaded lazily on first use
        # Center-crop and resize to (height, width), convert to a [0, 1] tensor,
        # then map to [-1, 1] with mean = std = 0.5.
        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Scale a PIL image, keeping aspect ratio, so it covers (self.height, self.width).

        The larger of the two scale factors is used, so both sides end up at least
        as large as the target; the excess is removed later by ``CenterCrop``.
        """
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read ``num_frames`` frames starting at ``start_frame_id`` with stride ``interval``.

        Returns:
            ``None`` if the video has fewer than ``max_num_frames`` frames or is too
            short for the requested window. Otherwise a tensor laid out (C, T, H, W).
            When ``self.is_i2v`` is True, returns ``(frames, first_frame)`` instead,
            where ``first_frame`` is ``np.array`` of the resized but not yet
            center-cropped or normalized first frame.
        """
        reader = imageio.get_reader(file_path)
        try:
            # count_frames() can be expensive for some backends; call it once.
            total_frames = reader.count_frames()
            # Accept tensors (e.g. from torch.randint) as well as ints.
            start_frame_id = int(start_frame_id)
            last_frame_id = start_frame_id + (num_frames - 1) * interval
            if total_frames < max_num_frames or total_frames - 1 < last_frame_id:
                return None
            frames = []
            first_frame = None
            for frame_id in range(num_frames):
                frame = Image.fromarray(reader.get_data(start_frame_id + frame_id * interval))
                frame = self.crop_and_resize(frame)
                if first_frame is None:
                    first_frame = np.array(frame)
                frames.append(frame_process(frame))
        finally:
            # Always release the reader, including on decode/transform errors.
            reader.close()
        frames = rearrange(torch.stack(frames, dim=0), "T C H W -> C T H W")
        if self.is_i2v:
            return frames, first_frame
        return frames

    def is_image(self, file_path):
        """Return True if the path's text after the last '.' is a known image extension."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]

    def load_video(self, file_path):
        """Load one clip with a random start frame (see ``load_frames_using_imageio``)."""
        # The upper bound keeps the last sampled frame below max_num_frames.
        start_frame_id = int(torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0])
        return self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id,
                                              self.frame_interval, self.num_frames, self.frame_process)

    def parse_matrix(self, matrix_str):
        """Parse a string like ``"[a b c] [d e f]"`` into a 2-D float numpy array.

        Rows are separated by ``'] ['``; values within a row by whitespace.
        """
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)

    def get_relative_pose(self, cam_params):
        """Express every camera pose relative to the first one.

        Each element of ``cam_params`` must expose ``w2c_mat`` and ``c2w_mat``
        (presumably 4x4 matrices). The first returned pose is the identity-based
        ``target_cam_c2w``; each later pose is
        ``target_cam_c2w @ w2c[0] @ c2w[i]``.

        Returns:
            float32 numpy array stacking the ``len(cam_params)`` poses.
        """
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        cam_to_origin = 0
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        return np.array(ret_poses, dtype=np.float32)

    def _load_camera_data(self):
        """Return the parsed camera-extrinsics JSON, reading the file only once."""
        if self._cam_data is None:
            with open(self.camera_path, 'r') as file:
                self._cam_data = json.load(file)
        return self._cam_data

    def _build_camera_embedding(self, num_frames):
        """Build flattened relative camera poses for ``self.cam_type``.

        Samples every 4th frame index in ``range(num_frames)``, expresses each pose
        relative to the first sampled one, keeps the top 3 rows and flattens them,
        and returns a bfloat16 tensor of shape (num_sampled, 12).
        """
        cam_data = self._load_camera_data()
        cam_key = f"cam{int(self.cam_type):02d}"
        cam_idx = list(range(num_frames))[::4]
        traj = [self.parse_matrix(cam_data[f"frame{idx}"][cam_key]) for idx in cam_idx]
        # Transpose each stored matrix. NOTE(review): presumably the JSON stores
        # them transposed relative to the c2w convention used below; confirm.
        traj = np.stack(traj).transpose(0, 2, 1)
        c2ws = []
        for c2w in traj:
            # Fancy indexing copies, so the in-place edits below do not touch traj.
            # Column permutation, sign flip and /100 are presumably an axis
            # convention change and unit conversion; TODO confirm with data source.
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        first_cam = tgt_cam_params[0]
        # Index [1] picks the pose of `cam` relative to `first_cam`.
        relative_poses = [
            torch.as_tensor(self.get_relative_pose([first_cam, cam]))[:, :3, :][1]
            for cam in tgt_cam_params
        ]
        pose_embedding = torch.stack(relative_poses, dim=0)  # (num_sampled, 3, 4); 21 for 81 frames
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        return pose_embedding.to(torch.bfloat16)

    def __getitem__(self, data_id):
        """Return the sample dict for ``data_id``.

        Raises:
            ValueError: if the video cannot be loaded or does not have
                ``EXPECTED_NUM_FRAMES`` frames.
        """
        text = self.text[data_id]
        path = self.path[data_id]
        video = self.load_video(path)
        if video is None:
            raise ValueError(f"{path} is not a valid video.")
        first_frame = None
        if self.is_i2v:
            # In image-to-video mode the loader returns (frames, first_frame).
            video, first_frame = video
        num_frames = video.shape[1]
        if num_frames != self.EXPECTED_NUM_FRAMES:
            raise ValueError(f"{path}: expected {self.EXPECTED_NUM_FRAMES} frames, got {num_frames}.")
        data = {"text": text, "video": video, "path": path}
        if self.is_i2v:
            data["first_frame"] = first_frame
        data['camera'] = self._build_camera_embedding(num_frames)
        return data

    def __len__(self):
        """Number of entries in the metadata CSV."""
        return len(self.path)
最新发布