For a walkthrough of the paper, see: https://blog.youkuaiyun.com/JustWantToLearn/article/details/138758033
Code: https://github.com/megvii-research/CADDM
Here we briefly describe the algorithm flow, focusing on how the model is built and why it is built that way.
Part 1: dataset preparation, see https://blog.youkuaiyun.com/JustWantToLearn/article/details/138773005
Part 2: dataset loading, including the Multi-scale Facial Swap (MFS) module: https://blog.youkuaiyun.com/JustWantToLearn/article/details/139092687
Part 3: the training process and the ADM module (this post)
Table of Contents
1. Training: train.py
python train.py --cfg ./configs/caddm_train.cfg
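The `train()` function below reads all of its settings from this config file. The actual `caddm_train.cfg` in the repository may use different values and contain additional keys; the hypothetical sketch below is reconstructed purely from the keys the code accesses, with placeholder values:

```yaml
# Hypothetical sketch -- keys inferred from train.py, all values are placeholders.
model:
  backbone: resnet34          # placeholder backbone name
train:
  batch_size: 32
  epoch_num: 100
dataset:
  img_path: /path/to/train/images
det_loss:                     # arguments forwarded to MultiBoxLoss
  num_classes: 2
  overlap_thresh: 0.5
  prior_for_matching: true
  bkg_label: 0
  neg_mining: true
  neg_pos: 3
  neg_overlap: 0.5
  encode_target: false
  use_gpu: true
```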
def train():
    args = args_func()

    # load configs
    cfg = load_config(args.cfg)

    # init model
    net = model.get(backbone=cfg['model']['backbone'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net = nn.DataParallel(net)

    # loss init: the multi-task detection loss MultiBoxLoss
    # plus nn.CrossEntropyLoss for the final classification
    det_criterion = MultiBoxLoss(
        cfg['det_loss']['num_classes'],
        cfg['det_loss']['overlap_thresh'],
        cfg['det_loss']['prior_for_matching'],
        cfg['det_loss']['bkg_label'],
        cfg['det_loss']['neg_mining'],
        cfg['det_loss']['neg_pos'],
        cfg['det_loss']['neg_overlap'],
        cfg['det_loss']['encode_target'],
        cfg['det_loss']['use_gpu']
    )
    criterion = nn.CrossEntropyLoss()

    # optimizer init
    optimizer = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=4e-3)

    # load checkpoint if given
    base_epoch = 0
    if args.ckpt:
        net, optimizer, base_epoch = load_checkpoint(args.ckpt, net, optimizer, device)

    # load the training dataset
    print(f"Load deepfake dataset from {cfg['dataset']['img_path']}..")
    train_dataset = DeepfakeDataset('train', cfg)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['train']['batch_size'],
                              shuffle=True, num_workers=4,
                              collate_fn=my_collate
                              )

    # start training: switch to train mode and loop over epochs and batches,
    # refreshing the learning rate from the epoch-based schedule
    net.train()
    for epoch in range(base_epoch, cfg['train']['epoch_num']):
        for index, (batch_data, batch_labels) in enumerate(train_loader):

            lr = update_learning_rate(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            labels, location_labels, confidence_labels = batch_labels
            labels = labels.long().to(device)
            location_labels = location_labels.to(device)
            confidence_labels = confidence_labels.long().to(device)

            # compute the classification and detection losses,
            # form the total loss, and backpropagate
            optimizer.zero_grad()
            locations, confidence, outputs = net(batch_data)
            loss_end_cls = criterion(outputs, labels)
            loss_l, loss_c = det_criterion(
                (locations, confidence),
                confidence_labels, location_labels
            )
            acc = sum(outputs.max(-1).indices == labels).item() / labels.shape[0]
            det_loss = 0.1 * (loss_l + loss_c)
            loss = det_loss + loss_end_cls
            loss.backward()

            # gradient clipping and optimizer step
            torch.nn.utils.clip_grad_value_(net.parameters(), 2)
            optimizer.step()

            outputs = [
                "e:{},iter: {}".format(epoch
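To make the loss combination at the end of the loop concrete, here is a minimal, self-contained sketch. Random tensors stand in for the real network outputs and labels, and the two MultiBoxLoss terms are replaced by placeholder scalars, since that loss's internals are not shown in this post:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, num_classes = 4, 2
criterion = nn.CrossEntropyLoss()

# stand-ins for the real network outputs and labels
outputs = torch.randn(batch_size, num_classes)    # final classification logits
labels = torch.randint(0, num_classes, (batch_size,))
loss_l = torch.tensor(1.2)    # placeholder for the MultiBoxLoss localization term
loss_c = torch.tensor(0.8)    # placeholder for the MultiBoxLoss confidence term

# same arithmetic as in train.py: the detection loss is down-weighted by 0.1
loss_end_cls = criterion(outputs, labels)
det_loss = 0.1 * (loss_l + loss_c)
loss = det_loss + loss_end_cls

# batch accuracy, computed the same way as in train.py
acc = (outputs.max(-1).indices == labels).sum().item() / labels.shape[0]

print(round(det_loss.item(), 4))    # 0.2, i.e. 0.1 * (1.2 + 0.8)
```

The 0.1 weight keeps the auxiliary detection objective from dominating the final real/fake classification loss; both terms flow into a single `backward()` call, so one optimizer step updates the shared backbone with gradients from both tasks.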