1. 领域自适应网络概述
有时候我们在某个领域训练出的一个模型,想迁移到另一个领域,这样我们就不需要每个领域都去标注大量的数据了。但是这两个领域的数据分布是有些差异,要如何办呢?比如我们在黑白图片上训练出了数字的识别模型,但是我们希望该模型用到彩色数字的识别上。如果直接将模型迁移过去,结果并不理想,在黑白图片上的测试准确率可以达到99.5%,但迁移到彩色图片上时测试准确率仅57.5%。这是由于黑白图片和彩色图片的分布不一样,那么如何处理这种情况呢?我们需要用到 “领域自适应(Domain Adaption)”。

2. 领域自适应的基本思路
我们将有标签的、能够训练模型的领域称为源域(Source Domain),无标签或者只有少量标签的领域称为目标域(Target Domain),我们的目的是要将源域训练得到的模型迁移到目标域。但是源域数据和目标域数据的分布不一样,领域自适应的基本思路是设计一个特征提取器,使得从源域和目标域提取的特征分布是一样的,如下图所示。

我们将数字识别任务分成特征提取器和标签预测器两个部分,特征提取器有若干层网络负责提取图片的特征,并输出一个向量,然后将这个向量交给标签预测器,标签预测器也有若干层,负责根据特征提取输出的向量预测图片所展示的数字。现在有一堆的源域数据,它们是有标签的,还有一堆目标域数据,它们是没有标签的。假设源域数据输入到特征提取器输出的结果是蓝色的点,而目标域数据输入到特征提取器输出的结果是红色的点。我们要训练这个特征提取器,尽可能地使蓝色的点和红色地点混在一起分不出差异。那么要如何训练这个特征提取器呢?

3. 域分类器的引入
除了上述的特征提取器和标签预测器外,我们再引入一个新的网络,叫做 “域分类器”。域分类器的任务就是负责鉴别特征提取器输出的特征是来自源域数据还是目标域数据。而特征提取器要尽量减小源域数据数据和目标域数据输出向量的差异,以骗过域分类器,使其无法正确的鉴别。说到这里,大家应该想到了前面的生成对抗网络(GAN),这两者确实很像。这里的特征提取器就相当于GAN的生成器,而这里的域分类器就相当于GAN的鉴别器。那么有一个问题,特征提取器似乎明显要占优势,比如一个极端的情况,特征提取器无论是何输入都输出0,那么鉴别器无论如何也无法鉴别出哪些来自源域,哪些来自目标域。但是这种情况其实是不会出现的。因为预测器也需要特征提取器输出的这个向量,用来判断图片的标签是什么。如果特征提取器无论输入是什么都输出0的话,那么预测器是无法根据这个向量来预测图片标签的。

下面再利用数学符号把上述过程再理一下。首先,我们假设三个网络的参数分别为:特征提取器—— θ f \theta_f θf,标签预测器—— θ p \theta_p θp,域分类器—— θ d \theta_d

本文介绍了领域自适应的概念,旨在解决不同领域数据分布差异导致的模型迁移问题。通过引入特征提取器、标签预测器和域分类器,使得模型能在源域(如黑白图片)训练后适应目标域(如彩色图片)。特征提取器尝试混淆源域和目标域的特征分布,而域分类器则试图区分两者,形成类似于生成对抗网络的训练过程。最终目标是通过训练使模型在目标域上保持高准确性。
最低0.47元/天 解锁文章
8145





