对抗判别式领域自适应

对抗判别式领域自适应

论文链接:https://ieeexplore.ieee.org/document/8099799/

文献:E. Tzeng, J. Hoffman, K. Saenko, and T. Darrell, “Adversarial discriminative domain adaptation,” in Proceedings of the 30th IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2017, 2017, vol. 2017-July, pp. 2962–2971, doi: 10.1109/CVPR.2017.316.

领域自适应技术在不匹配说话人识别的问题中非常有效。这篇文章是图像领域的判别式对抗自适应方法,也同样可以迁移至说话人识别领域。

摘要

目的:在无监督领域自适应问题中,对抗方法是减少训练分布和测试分布之间的差别、改善泛化性能的有效又短。但如今的生成式方法在判别判别任务上性能不佳,而判别式方法,尽管能处理较大的域变换,但还未充分利用对抗生成网络的损失函数。
数据与方法:作者提出了对抗自适应方法的广义框架,进而在此基础上,提出了对抗判别式领域自适应方法 (ADDA),该方法涉及判别式模型、无共享权重和 GAN 损失。提出的方法在三个任务上进行测试:无监督领域自适应基准任务 - 数字(MNIST、USPS 和 SVHN)、跨模态的自适应学习任务(NYUD)和跨视觉域的自适应学习任务(标准 Office - amazon, webcam, dslr)。
结果:在基准任务上,当源域和目标域相近时( MNIST ⇋ USPS \text{MNIST}\leftrightharpoons \text{USPS} MNISTUSPS),ADDA 与生成式方法相当;在跨模态的自适应学习任务上,大部分类别的分类性能得到了显著改进,但也存在性能降低的类别;在跨视觉域的自适应学习任务上,不同模型的ADDA都获得了一致的性能提升。

1. 引言

深度学习在各种任务和视觉领域上能够学到各种表示,然而,领域变化/领域偏差导致这些表示在新的数据集和任务上的泛化效果不佳。针对这一问题,典型的解决方法是在针对任务的数据集上进行精调。但是,获取用于精调深度网络的大规模数据是非常困难。

领域自适应方法可以减轻领域变化的有害影响,其思想是学习到两个领域的共同特征空间。其手段是以实现最小化领域变化的距离为目标优化表示,其中领域变化的衡量方法是:

  1. maximum mean discrepancy (MMD):计算两个域均值之差的范数;
  2. correlation distances (CORAL):匹配两个分布的均值和协方差;
  3. 对抗损失:例如 GAN 方法的损失、反向梯度、领域混淆的损失。

对抗自适应方法是通过关于领域判别器的对抗优化目标来最小化的领域散度距离,常见的案例如生成对抗网络让生成器产生让判别器误导的图片。不同的对抗自适应方法的设计需要考虑三点:

  1. 是否使用生成器,
  2. 采用何种损失函数,
  3. 是否共享跨域权重。

2. 广义对抗自适应

对抗无监督自适应方法的一般框架:

  1. 源域: X s X_s Xs 及其标签 Y s Y_s Ys p s ( x , y ) p_s(x,y) ps(x,y) 获得,

  2. 目标域: X t X_t Xt p t ( x , y ) p_t(x,y) pt(x,y) 获得,但是无标签,

  3. 目标:学习目标域表示 M t M_t Mt 和目标域分类器 C t C_t Ct,使得能够在测试阶段正确地分类目标样本,即便在缺少域注释的情况下;

  4. 领域自适应方法:学习源域的表示映射 M s M_s Ms 和源域的分类器 C s C_s Cs,然后学会适应在目标域上的模型使用。

  5. 对抗自适应方法:主要目标是正则化源域和目标域映射( M s M_s Ms M t M_t Mt)的学习过程,以实现最小化经验性的源域和目标域映射( M s ( X s ) M_s(X_s) Ms(Xs) M t ( X t ) M_t(X_t) Mt(Xt))分布之间的距离。最后使得源域分类器 C s C_s Cs 能够师姐用于目标域表示,而不需要再单独为目标域训练分类器,即 C = C s = C t C = C_s = C_t C=Cs=Ct

其中分类器的监督损失:
min ⁡ M s , C   L cls ( X s , Y s ) = − E ( x s , y s ) ∼ ( X s , Y s ) ∑ k = 1 K 1 [ k = y s ] log ⁡ C ( M s ( x s ) ) \begin{aligned} &\min\limits_{M_s,C}\,\mathcal{L}_{\text{cls}}(\mathbf{X}_s,Y_s)=\\ &\quad\quad-\mathbb{E}_{(\mathbf{x}_s,y_s)\sim(\mathbf{X}_s,Y_s)}\sum\limits_{k=1}^K\mathbb{1}_{[k=y_s]}\log{C(M_s(\textbf{x}_s))} \end{aligned} Ms,CminLcls(Xs,Ys)=E(xs,ys)(Xs,Ys)k=1K1[k=ys]logC(Ms(xs))
判别器的监督损失:
L adv D ( X s , X t , M s , M t ) = − E x s ∼ X s [ log ⁡ D ( M s ( x s ) ) ] − E x t ∼ X t [ log ⁡ ( 1 − D ( M t ( x t ) ) ) ] \begin{aligned} &\mathcal{L}_{\text{adv}_D}(\mathbf{X}_s,\mathbf{X}_t,M_s,M_t)=\\ &\quad\quad-\mathbb{E}_{\mathbf{x}_s\sim\mathbf{X}_s}\left[\log{D(M_s(\textbf{x}_s))}\right]\\ &\quad\quad\quad\quad-\mathbb{E}_{\mathbf{x}_t\sim\mathbf{X}_t}\left[\log{(1-D(M_t(\textbf{x}_t)))}\right] \end{aligned}

领域自适应是迁移学习的关键子领域,目的是解决源领域和目标领域间数据分布差异,在目标领域缺乏标注数据时,通过减少两领域分布差异,让模型在目标领域有更好表现[^1]。 对抗自适应属于领域自适应的一种具体方法。从作用机制上看,领域自适应是一个较为宽泛的概念,强调的是整体上减少源领域和目标领域之间的数据分布差异这一目标。而对抗自适应是利用对抗训练的方式来达成领域自适应的目的。如相关研究中用目标领域图片训练和测试,正确率较高,但用源领域训练、目标领域测试结果较差,使用领域对抗训练(即对抗自适应方法)后正确率有明显提升,这体现了对抗自适应这种具体方法在解决领域偏移问题上的有效性,是实现领域自适应迁移学习目标的一种途径[^1][^3]。 ```python # 这里简单示意对抗自适应在代码上可能的体现 # 假设我们有一个简单的对抗训练框架 import torch import torch.nn as nn import torch.optim as optim # 定义特征提取器 class FeatureExtractor(nn.Module): def __init__(self): super(FeatureExtractor, self).__init__() # 这里简单定义一个线性层作为示例 self.fc = nn.Linear(10, 5) def forward(self, x): return self.fc(x) # 定义判别器 class DomainDiscriminator(nn.Module): def __init__(self): super(DomainDiscriminator, self).__init__() self.fc = nn.Linear(5, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): return self.sigmoid(self.fc(x)) # 初始化模型 feature_extractor = FeatureExtractor() domain_discriminator = DomainDiscriminator() # 定义优化器 optimizer_feature = optim.Adam(feature_extractor.parameters(), lr=0.001) optimizer_discriminator = optim.Adam(domain_discriminator.parameters(), lr=0.001) # 简单的训练循环 for epoch in range(10): # 模拟源领域和目标领域数据 source_data = torch.randn(100, 10) target_data = torch.randn(100, 10) # 特征提取 source_features = feature_extractor(source_data) target_features = feature_extractor(target_data) # 合并特征 all_features = torch.cat((source_features, target_features), dim=0) domain_labels = torch.cat((torch.zeros(source_features.size(0), 1), torch.ones(target_features.size(0), 1)), dim=0) # 训练判别器 optimizer_discriminator.zero_grad() domain_pred = domain_discriminator(all_features.detach()) d_loss = nn.BCELoss()(domain_pred, domain_labels) d_loss.backward() optimizer_discriminator.step() # 训练特征提取器 optimizer_feature.zero_grad() domain_pred = domain_discriminator(all_features) g_loss = -nn.BCELoss()(domain_pred, domain_labels) g_loss.backward() optimizer_feature.step() ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值