对前几天的对抗损失总结一下,转载请注明出处,如有不对的地方,欢迎前来指出,一起探讨。
1.对抗损失的目的与作用
对抗损失的使用主要是为了减少标注数据,在真实的业务中,对于数据的标注是一件非常头疼的事,为了使用1000条标注能够达到2000条标注数据的所能达到效果(打个比方),模拟真实世界中各种噪声的情况,让模型更加鲁棒,更好用,准确率更高,在图像处理中经常使用引入噪声来增加图像的样本集,在文本类数据中怎么来使用呢?
那具体怎么做呢?一般方法是将真实数据+噪声,使得数据散度更大,让有标签与无标签的数据在模型上一起进行学习。
2.对抗损失的方法
- 一般对抗损失(在进行一般对抗的时候,首先对要对batch_size的数据进行一次损失的计算,根据损失对输入(可以是字向量、词向量、词性向量、偏旁向量等特征的拼接)的向量进行偏导数计算,但要注意的是需要对输入数据进行stop_gradiants,然后刚才求得的Grad进行L2正则化处理,将处理后的值称作扰动perturb,将perturb与embedding一起add,然后在进行一次损失的计算,起到了对每个batch_size的数据进行了两次的计算,增加模型的鲁棒性)
- 随机对抗损失(随机生成一个形状与embedding相同的向量,然后进行MASK操作,再进行L2正则化生成pertub,最后将生成的噪声添加到输入特征向量上进行再次计算损失)
- 虚拟对抗损失(虚拟对抗与随机损失有点相似,但是引入了KL散度,具体看下面实现)
3.对抗损失具体实现
一般对抗损失
def adversarial_loss(embedded, loss, loss_fn):
"""Adds gradient to embedding and recomputes classification loss."""
grad, = tf.gradients(
loss,
embedded)
grad = tf.stop_gr