文献地址:
https://arxiv.org/pdf/1605.07725.pdf
摘要:
"对抗训练"提供了一种正则化的监督学习算法.
"虚拟对抗训练"能够将监督学习算法拓展到半监督学习的环境中.
两者都是对输入向量进行较小的扰动,不适用于稀疏高维的输入向量.
从而提升泛化能力.
简介:
先前的对抗训练主要应用与图像分类,本文中要拓展到文本分类任务和序列模型.
本文的对抗训练就是对输入相连进行小的修改.
模型:
如上图左边是一个常规的双向LSTM,输入的稀疏向量w 被转化为连续向量v时候,输入到LSTM模型中,然后用模型的输出做文本分类,而右图则是对抗生成网络的输入.
其中输入相连V需要归一化之后在加扰动项,归一化的公式如下:
对抗和虚拟对抗训练:
对抗训练:
将Θ作为参数,x作为输入,损失函数如下:
其中r是输入的扰动.
r(adv) 是固定Θ的情况下(为了反向传播时不会通过这里),最小化p(y|x+r)概率时的r值.
即给一个扰动使得结果最差
然后用训练的原始数据去加上这个扰动,并调整参数Θ(从而提升模型的鲁棒性)
虚拟对抗训练:
损失函数如下:
KL散度是计算两个分布的差异,如上图公式,先找出固定参数Θ时差异最大的扰动项,然后去更新Θ以最小化这个加了扰动项的特征相连和原始特征相连的差异.是分类器对当前(此时的Θ决定的模型)最敏感的方向上具有抗干扰性.
这种虚拟对抗训练的方式不需要真实的label,这使得其可以应用到半监督学习上去.
本文对抗训练:
在本文中需要把输入的向量做归一化处理并定义为s,模型的条件概率就是p(y|s;Θ),Θ是模型参数.
定义扰动项r(adv)如下:
对抗部分的损失为每个样本损失的平均值:
本文虚拟对抗训练:
虚拟对抗模型的扰动项:
虚拟对抗部分的损失函数:
实验设计:
LSTM + ADV
结果:
e=5.0
从上图看到对抗网络具有较好的鲁棒性,对最终的性能可能带来2-3%的提升(收敛速度很慢)
上图所示,虚拟对抗训练的损失最小
上图是和"good"和"bad"最接近的词语(余弦最小)
可以看到"good"和"bad"有着相似的语法作用(都是修饰某个东西),但是其含义是不同的.
而对抗训练就能区分开这种语法上相似,但含以上不相似的特征