文章目录
前言
对抗训练无论是在CV领域还是在NLP领域都具有举足轻重的地位,在NLP比赛中,抗训练确确实实能够提升模型在具体任务上的泛化性能。
一、什么是对抗训练?
对抗样本
:对输入增加微小扰动得到的样本。旨在增加模型损失。
对抗训练
:训练模型去区分样例是真实样例还是对抗样本的过程。对抗训练不仅可以提升模型对对抗样本的防御能力,还能提升对原始样本的泛化能力。
1、FGM——Fast Gradient Method
FSGM是每个方向上都走相同的一步,2017年Goodfellow后续提出的FGM则是根据具体的梯度进行scale,得到更好的对抗样本:
对于每个x:
1.计算x的前向loss、反向传播得到梯度
2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r
3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上
4.将embedding恢复为(1)时的值
5.根据(3)的梯度对参数进行更新
Pytorch实现
class FGM():
""" 快速梯度对抗训练
"""
def __init__(self, model):
self.model = model
self.backup = {}
def attack(self, epsilon=1., emb_name='word_embeddings'):
# emb_name这个参数要换成你模型中embedding的参数名
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0 and not torch.isnan(norm):
r_at = epsilon * param.grad / norm
param.data.add_(r_at)
def restore(self, emb_name='word_embeddings'):
# emb_name这个参数要换成你模型中embedding的参数名
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
一、使用步骤
for step, batch in enumerate(train_dataloader):# 遍历批数据
# Add batch to GPU
batch = tuple(t.to(device) for t in batch)
# Unpack the inputs from our dataloader
# 每一批数据展开
# train_inputs.extend(one_freq_input_ids)
# train_labels.extend(one_freq_labels)
# train_masks.extend(one_freq_attention_masks)
# train_token_types.extend(one_freq_token_types)
# 接收batch的输入
b_input_ids, b_input_mask, b_labels, b_token_types = batch
outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
logits = outputs[0]
loss_func = BCEWithLogitsLoss() # 计算损失
loss = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels))
train_loss_set.append(loss.item())# 记录loss
# Backward pass
loss.backward(retain_graph=True) # loss反向求导
#对抗训练
fgm.attack()
loss_adv = loss_func(logits.view(-1,num_labels),b_labels.type_as(logits).view(-1,num_labels))
loss_adv.backward(retain_graph=True)
fgm.restore()
#梯度更新
optimizer.step()
model.zero_grad()
总结
对抗训练中关键的是需要找到对抗样本(尽量让模型预测出错的样本),通常是对原始的输入添加一定的扰动来构造,然后用来给模型训练.