FGM对抗训练


前言

 对抗训练无论是在CV领域还是在NLP领域都具有举足轻重的地位,在NLP比赛中,抗训练确确实实能够提升模型在具体任务上的泛化性能。


一、什么是对抗训练?

对抗样本:对输入增加微小扰动得到的样本。旨在增加模型损失

对抗训练:训练模型去区分样例是真实样例还是对抗样本的过程。对抗训练不仅可以提升模型对对抗样本的防御能力,还能提升对原始样本的泛化能力

1、FGM——Fast Gradient Method

FSGM是每个方向上都走相同的一步,2017年Goodfellow后续提出的FGM则是根据具体的梯度进行scale,得到更好的对抗样本:

对于每个x:
  1.计算x的前向loss、反向传播得到梯度
  2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r
  3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上
  4.将embedding恢复为(1)时的值
  5.根据(3)的梯度对参数进行更新

Pytorch实现
class FGM():
    """ 快速梯度对抗训练
    """
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings'):
        # emb_name这个参数要换成你模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        # emb_name这个参数要换成你模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

一、使用步骤

  for step, batch in enumerate(train_dataloader):# 遍历批数据
    # Add batch to GPU
    batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    # 每一批数据展开
    # train_inputs.extend(one_freq_input_ids)
    # train_labels.extend(one_freq_labels)
    # train_masks.extend(one_freq_attention_masks)
    # train_token_types.extend(one_freq_token_types)
    # 接收batch的输入
    b_input_ids, b_input_mask, b_labels, b_token_types = batch

    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    logits = outputs[0]
    loss_func = BCEWithLogitsLoss() # 计算损失
    loss = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels)) 
    train_loss_set.append(loss.item())# 记录loss    

    # Backward pass
    loss.backward(retain_graph=True) # loss反向求导
    #对抗训练
    fgm.attack()
    loss_adv = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels))
    loss_adv.backward(retain_graph=True)
    fgm.restore()
    #梯度更新
    optimizer.step()
    model.zero_grad()

总结

对抗训练中关键的是需要找到对抗样本(尽量让模型预测出错的样本),通常是对原始的输入添加一定的扰动来构造,然后用来给模型训练.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值