Code Walkthrough, Part 3: Training and the ADM Module | CVPR 2023 | Implicit Identity Leakage: The Stumbling Block to Improving Deepfake Detection

For an explanation of the paper, see: https://blog.youkuaiyun.com/JustWantToLearn/article/details/138758033
Code: https://github.com/megvii-research/CADDM
Here we briefly describe the algorithm flow, focusing on how the model is built and why it is built that way.
Part 1: dataset preparation, see https://blog.youkuaiyun.com/JustWantToLearn/article/details/138773005
Part 2: dataset loading, including the Multi-scale Facial Swap (MFS) module: https://blog.youkuaiyun.com/JustWantToLearn/article/details/139092687
Part 3: the training process and the ADM module (this article)

1. Training: train.py

python train.py --cfg ./configs/caddm_train.cfg
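load_config parses the file passed via --cfg into a nested dict. The concrete contents of configs/caddm_train.cfg are not reproduced here; judging only from the keys the script reads below, the returned structure looks roughly like this (all values are illustrative placeholders):

# sketch of the dict returned by load_config(args.cfg); the key names come
# from the accesses in train.py, the values are placeholders rather than the
# repository's actual defaults
cfg = {
    'model':   {'backbone': 'efficientnet-b4'},   # passed to model.get()
    'dataset': {'img_path': './data/train'},      # training image root
    'train':   {'batch_size': 32, 'epoch_num': 50},
    'det_loss': {                                 # MultiBoxLoss arguments
        'num_classes': 2,
        'overlap_thresh': 0.5,
        'prior_for_matching': True,
        'bkg_label': 0,
        'neg_mining': True,
        'neg_pos': 3,
        'neg_overlap': 0.5,
        'encode_target': False,
        'use_gpu': True,
    },
}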

def train():
    args = args_func()

    # load configs
    cfg = load_config(args.cfg)

    # init model
    net = model.get(backbone=cfg['model']['backbone'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    net = nn.DataParallel(net)
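    # nn.DataParallel replicates the model on every visible GPU and splits each
    # batch along dim 0; CPU input tensors are scattered to the GPUs
    # automatically, which is why batch_data is passed to net() below without
    # an explicit .to(device)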

    # loss init: the multi-task detection loss MultiBoxLoss plus a cross-entropy loss nn.CrossEntropyLoss
    det_criterion = MultiBoxLoss(
        cfg['det_loss']['num_classes'],
        cfg['det_loss']['overlap_thresh'],
        cfg['det_loss']['prior_for_matching'],
        cfg['det_loss']['bkg_label'],
        cfg['det_loss']['neg_mining'],
        cfg['det_loss']['neg_pos'],
        cfg['det_loss']['neg_overlap'],
        cfg['det_loss']['encode_target'],
        cfg['det_loss']['use_gpu']
    )
    criterion = nn.CrossEntropyLoss()
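    # MultiBoxLoss is the SSD-style detection loss (smooth-L1 box regression
    # plus confidence cross-entropy with hard negative mining, as its argument
    # names suggest); it supervises the anchor outputs of the Artifact
    # Detection Module (ADM). The plain CrossEntropyLoss supervises the final
    # real/fake prediction.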

    # optimizer init.
    optimizer = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=4e-3)

    # load checkpoint if given
    base_epoch = 0
    if args.ckpt:
        net, optimizer, base_epoch = load_checkpoint(args.ckpt, net, optimizer, device)

    # load the training dataset
    print(f"Load deepfake dataset from {
     
     cfg['dataset']['img_path']}..")
    train_dataset = DeepfakeDataset('train', cfg)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['train']['batch_size'],
                              shuffle=True, num_workers=4,
                              collate_fn=my_collate
                              )
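    # my_collate comes from the dataset code in part 2; it is assumed here to
    # stack the face crops into one batch tensor and to pack the three label
    # tensors (global real/fake label, location targets, confidence targets)
    # into batch_labels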

    # start training: loop over epochs and batches; the learning rate is
    # recomputed from the current epoch at every iteration
    net.train()
    for epoch in range(base_epoch, cfg['train']['epoch_num']):
        for index, (batch_data, batch_labels) in enumerate(train_loader):

            lr = update_learning_rate(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            labels, location_labels, confidence_labels = batch_labels
            labels = labels.long().to(device)
            location_labels = location_labels.to(device)
            confidence_labels = confidence_labels.long().to(device)
            # compute the classification and detection losses, combine them
            # into the total loss (the detection term is down-weighted by 0.1),
            # and backpropagate
            optimizer.zero_grad()
            locations, confidence, outputs = net(batch_data)
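            # the forward pass returns the ADM's per-anchor box regressions
            # (locations) and class scores (confidence), plus the final
            # real/fake logits (outputs)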
            loss_end_cls = criterion(outputs, labels)
            loss_l, loss_c = det_criterion(
                (locations, confidence),
                confidence_labels, location_labels
            )
            acc = sum(outputs.max(-1).indices == labels).item() / labels.shape[0]
            det_loss = 0.1 * (loss_l + loss_c)
            loss = det_loss + loss_end_cls
            loss.backward()
            # clip each gradient element to [-2, 2], then take an optimizer step
            torch.nn.utils.clip_grad_value_(net.parameters(), 2)
            optimizer.step()

            # per-iteration console log (the snippet was truncated here; the
            # remaining fields are reconstructed from the variables in scope)
            outputs = [
                "e:{},iter: {}".format(epoch, index),
                "acc: {:.4f}".format(acc),
                "loss: {:.4f}".format(loss.item()),
                "lr: {:.6f}".format(lr),
            ]
            print(" ".join(outputs))