def restore_rgb_img(batch_img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization of a batch of frame sequences, in place.

    Every channel ``c`` is mapped back with ``x * std[c] + mean[c]``, which is
    the inverse of ``(x - mean[c]) / std[c]``. The default ``mean``/``std``
    presumably are the common ImageNet statistics (confirm against the code
    that normalized the images).

    Args:
        batch_img: normalized images of shape
            (batch_size, num_frames, C, H, W). It must support 5-D indexing and
            item assignment (e.g. a numpy array or torch tensor; assumption).
            It is modified in place.
        mean: per-channel means used for the original normalization; its
            length must equal C.
        std: per-channel standard deviations used for the original
            normalization; its length must equal C.

    Returns:
        The same ``batch_img`` object, now holding the restored values.

    Raises:
        ValueError: if ``batch_img`` is not 5-D, or if ``len(mean)`` or
            ``len(std)`` differs from the channel count C.
    """
    shape = batch_img.shape
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and a length mismatch would then make zip() silently skip channels.
    if len(shape) != 5:
        raise ValueError(
            "batch_img must be 5-D (batch_size, num_frames, C, H, W), "
            f"got shape {tuple(shape)}"
        )
    num_channels = shape[2]
    if len(mean) != num_channels or len(std) != num_channels:
        raise ValueError(
            f"mean and std must each have {num_channels} entries (one per "
            f"channel), got len(mean)={len(mean)} and len(std)={len(std)}"
        )

    for channel_index, (channel_mean, channel_std) in enumerate(zip(mean, std)):
        # Assigning through the channel index writes the result back into
        # batch_img itself, so the caller's object is updated in place.
        batch_img[:, :, channel_index] = (
            batch_img[:, :, channel_index] * channel_std + channel_mean
        )
    return batch_img
def vis_cos_dist(cos_dist, batch_idx, width