论文讲解请看:https://blog.youkuaiyun.com/JustWantToLearn/article/details/138758033
代码链接:https://github.com/megvii-research/CADDM
在这里,我们简要描述算法流程,着重分析模型搭建细节,以及为什么要这样搭建。
part 1:数据集准备,请看链接 https://blog.youkuaiyun.com/JustWantToLearn/article/details/138773005
part 2:数据集加载,包含 Multi-scale Facial Swap(MFS)模块(即本文)
part 3:训练过程,ADM模块 https://blog.youkuaiyun.com/JustWantToLearn/article/details/139116455
文章目录
1、数据集加载
这里我在作者代码的基础上做了一些修改:当某个样本的 landmark 不存在时,直接跳过该样本。
最重要的函数:prepare_train_input
class DeepfakeDataset(Dataset):
    r"""Dataset of deepfake images paired with landmark/label metadata.

    The folder is expected to be organized as follows: root/cls/xxx.img_ext
    Labels are indices of sorted classes in the root directory.

    Args:
        mode: train or test.
        config: hyper parameters for processing images. This class reads
            ``config['dataset']['img_path']`` (image root) and
            ``config['dataset']['ld_path']`` (landmark JSON file).
    """

    def __init__(self, mode: str, config: dict):
        super().__init__()
        self.config = config
        self.mode = mode
        self.root = self.config['dataset']['img_path']
        self.landmark_path = self.config['dataset']['ld_path']
        self.rng = np.random
        assert mode in ['train', 'test']
        self.do_train = mode == 'train'
        # Metadata must be loaded first: collect_samples() looks every image
        # up in info_meta_dict.
        self.info_meta_dict = self.load_landmark_json(self.landmark_path)
        self.class_dict = self.collect_class()
        self.samples = self.collect_samples()

    def load_landmark_json(self, landmark_json) -> Dict:
        """Load and return the landmark metadata stored in a JSON file.

        Args:
            landmark_json: path of the JSON file to read.

        Returns:
            The decoded JSON content.
        """
        with open(landmark_json, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __getitem__(self, index: int) -> Tuple:
        """Load and preprocess the sample stored at ``index``.

        Returns:
            In train mode ``(img, (label, location_label, confidence_label))``,
            or ``(None, message)`` when ``prepare_train_input`` returns a
            string instead of a label dict (presumably a skip/error marker;
            confirm against its implementation).
            In test mode ``(img, (label, video_name))``.

        Raises:
            ValueError: if ``self.mode`` is neither ``'train'`` nor ``'test'``.
        """
        path, label_meta = self.samples[index]
        ld = np.array(label_meta['landmark'])  # landmarks of this sample
        label = label_meta['labels']  # class label of this sample
        img = cv2.imread(path, cv2.IMREAD_COLOR)

        if self.mode == "train":
            # The source image is only consumed by the training pipeline, so
            # it is decoded here rather than for every sample.
            source_img = cv2.imread(
                label_meta['source_path'], cv2.IMREAD_COLOR
            )
            img, label_dict = prepare_train_input(
                img, source_img, ld, label, self.config, self.do_train
            )
            if isinstance(label_dict, str):
                return None, label_dict
            location_label = torch.Tensor(label_dict['location_label'])
            confidence_label = torch.Tensor(label_dict['confidence_label'])
            # HWC -> CHW
            img = torch.Tensor(img.transpose(2, 0, 1))
            return img, (label, location_label, confidence_label)

        if self.mode == 'test':
            img, label_dict = prepare_test_input(
                [img], ld, label, self.config
            )
            # prepare_test_input receives a one-element list; take its first
            # output and convert HWC -> CHW.
            img = torch.Tensor(img[0].transpose(2, 0, 1))
            video_name = label_meta['video_name']
            return img, (label, video_name)

        raise ValueError("Unsupported mode of dataset!")

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)
1.1 收集样本 collect_samples
流程:构建每个文件的完整路径 path,并从路径中提取 info_key 和 video_name。
使用 info_key 从 info_meta_dict 获取记录信息 info_meta,包括 landmark、class_label 和 source_path。将路径和标记信息元组添加到 samples 列表中。
如果出现异常(如文件没有对应的标记信息),则增加 none_nums 计数器并打印。
def collect_samples(self) -> List:
    """Pair every image under the class directories with its metadata.

    For each key of ``self.class_dict`` that is a directory under
    ``self.root``, the tree is walked (following symlinks) in sorted order.
    An image's metadata key is its path without the file extension; the
    extension is appended to the record's ``source_path``.

    Files without a usable metadata record (missing key, missing field, or
    a label that cannot be converted to ``int``) are skipped, and the number
    of skipped files is printed once after the walk.

    Returns:
        A list of ``(path, meta)`` tuples where ``meta`` has the keys
        ``'labels'``, ``'landmark'``, ``'source_path'`` and ``'video_name'``
        (the directory containing the image).
    """
    samples = []
    none_nums = 0
    directory = os.path.expanduser(self.root)
    for key in sorted(self.class_dict):
        class_dir = os.path.join(directory, key)
        if not os.path.isdir(class_dir):
            continue
        for dirpath, _, filenames in sorted(
                os.walk(class_dir, followlinks=True)):
            for name in sorted(filenames):
                path = os.path.join(dirpath, name)
                # splitext handles extensions of any length (.jpeg, .png).
                info_key, ext = os.path.splitext(path)
                video_name = os.path.dirname(path)
                try:
                    info_meta = self.info_meta_dict[info_key]
                    landmark = info_meta['landmark']
                    class_label = int(info_meta['label'])
                    source_path = info_meta['source_path'] + ext
                except (KeyError, TypeError, ValueError):
                    # No usable metadata for this file: skip it.
                    none_nums += 1
                    continue
                samples.append(
                    (path, {
                        'labels': class_label,
                        'landmark': landmark,
                        'source_path': source_path,
                        'video_name': video_name,
                    })
                )
    print(none_nums)
    return samples
1.2 收集类别 collect_class
def collect_class(self) -> Dict:
    """Map each class directory name under ``self.root`` to an index.

    The immediate sub-directories of the root are the class names. They are
    sorted in descending order and numbered from 0 in that order.

    Returns:
        A dict whose keys are class names and whose values are ``np.int32``
        indices.
    """
    # Expand "~" so the root is resolved the same way as in collect_samples.
    root = os.path.expanduser(self.root)
    # The context manager closes the scandir iterator deterministically.
    with os.scandir(root) as entries:
        classes = sorted(
            (entry.name for entry in entries if entry.is_dir()),
            reverse=True,
        )
    return {name: np.int32(idx) for idx, name in enumerate(classes)}
1.3 prepare_train_input
将 targetRgb 和 sourceRgb 图像存储在 images 列表中。
包含multi_scale_facial_swap、label_assign
def prepare_train_input(targetRgb, sourceRgb, landmark, label, config, training=True):
'''Prepare model input images.
Arguments:
targetRgb: original images or fake images.
sourceRgb: source images.
landmark: face landmark.
label: deepfake labels. genuine: 0, fake: 1.
config: deepfake config dict.
training: return processed image with aug or not.
'''
#targetRgb: 原始图像或伪造图像
#sourceRgb:源图像
#landmark:81个人脸标记点
rng = np.random
images = [targetRgb, sourceRgb]
#如果是训练模式且随机数大于等于 0.7,则对图像和标记进行 resize_aug 数据增强。
if training and rng.rand() >= 0.7:
images, landmark = resize_aug(images, landmark)
# multi-scale facial swap.
targetRgb, sourceRgb = images
# if input image is genuine.
mfs_result, bbox = targetRgb, np.zeros((1, 4))
# if input image is fake image. generate new fake image with mfs.
if label:#如果图片为假
#随机选择混合类型 blending_type 为 'poisson' 或 'alpha'。
blending_type = 'poisson' if rng.rand() >= 0.5 else 'alpha'
if rng.rand() >= 0.2:
#如果随机数大于等于 0.2,则执行全局人脸交换:
# global facial swap.
sliding_win = targetRgb.shape[:2]
if rng.rand()

最低0.47元/天 解锁文章
555

被折叠的 条评论
为什么被折叠?



