Domain Adaptation(李宏毅)机器学习 2023 Spring HW11 (Boss Baseline)

1. 领域适配简介

领域适配是一种迁移学习方法,适用于源领域和目标领域数据分布不同但学习任务相同的情况。具体而言,我们在源领域(通常有大量标注数据)训练一个模型,并希望将其应用于目标领域(通常只有少量或没有标注数据)。然而,由于这两个领域的数据分布不同,模型在目标领域上的性能可能会显著下降。领域适配技术的目标是通过对模型进行适配,缩小源领域与目标领域之间的差距,从而提升模型在目标领域的表现。

Domain Shift (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

以数字识别为例,如果我们的源数据是灰度图像,并且在这些数据上训练模型,我们可以预期模型会取得相当不错的效果。然而,如果我们将这个在灰度图像上训练的模型用于分类彩色图像,模型的表现可能会较差。这是因为这两个数据集之间存在领域转移。

领域适配方法可以根据目标领域中标签的可用性进行分类:

  1. 有监督领域适配:源领域和目标领域都有标注数据。这种情况较为少见,因为领域适配的主要动机是目标领域标签的稀缺性。

  2. 无监督领域适配:源领域有标注数据,而目标领域没有标注数据。这是最常见且最具挑战性的情况。

  3. 半监督领域适配:源领域有标注数据,目标领域则只有少量标注数据。

Different Domain Adaptation Scenarios (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

我们的博客和作业主要关注目标领域缺乏标注数据的场景。

解决这个问题的基本概念如下:我们旨在找到一个特征提取器,它能够接收输入数据并输出特征空间。这个特征提取器应该能够滤除领域特定的变化,同时保留不同领域之间共享的特征。例如,在以下的示例中,特征提取器应该能够忽略图像的颜色,对于相同的数字,不论其颜色如何,都能生成具有相同分布的特征。

Basic Idea of Domain Adaptation (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

研究人员提出了许多方法,其中对抗学习方法是最常见且最有效的技术之一。

Domain Adversarial Training - 1 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

我们将一个标准网络分为两部分:特征提取器和标签预测器。在训练过程中,我们以标准的有监督方式在源领域数据上训练整个网络。对于目标领域数据,我们只使用特征提取器提取特征,并采用技术手段将目标领域的特征与源领域的特征对齐。

Domain Adversarial Training - 2 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

具体来说,我们设计了一个新的领域分类器,它是一个二分类器,输入特征向量并判断输入数据是来自源领域还是目标领域。另一方面,特征生成器的设计目的是“欺骗”领域分类器,使其无法正确区分来源领域。

Domain Adversarial Training - 3 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

如果我们仔细思考上述方法,我们可以直观地理解,尽管对抗训练可以使源领域和目标领域的整体分布更加相似,如下图左侧所示,但这种分布可能并不适合或不适用于机器学习任务。理想情况下,我们期望获得右侧图像所示的分布。

Limitation of Domain Adversarial Training (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

当然,已有大量论文提出了针对这一问题的解决方法。为了在这次作业中通过strong 和 boss baseline,我们需要深入相关文献,并采用合适的方法。在作业中,我将介绍更多相关的论文和技术。

2. Homework Results and Analysis

作业 11 聚焦于领域适配。给定真实图像(带标签)和涂鸦(无标签),任务是利用领域适配技术训练一个网络,能够准确预测绘制图像的标签。

task description (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

数据集设置:

  • 标签:10个类别(编号从0到9),如以下图片所示。

  • 训练集:5000张 (32, 32) RGB 真实图像(带标签)。

  • 测试集:100000张 (28, 28) 灰度绘制图像。

source and target data (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

baseline 的门槛 在 Kaggle 上的数值为:

Baseline

Public

Private

Simple

Score >= 0.44280

Score >= 0.44012

Medium

Score >= 0.65994

Score >= 0.65928

Strong

Score >= 0.75342

Score >= 0.75518

Boss

Score >= 0.81072

Score >= 0.80794

像往常一样,助教会提供关于如何超越各种基准模型的指导。

Hints for Simple, Medium and Strong Baseline (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

Hints for Boss Baseline (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

2.1 Simple Baseline

使用助教提供的默认代码足以通过 simple baseline。

2.2 Medium Baseline

通过增加训练轮数并调整超参数 lambda,可以通过 medium baseline。

num_epochs = 800
# train 800 epochs

with Progress(TextColumn("[progress.description]{task.description}"),
              BarColumn(),
              TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
              TimeRemainingColumn(),
              TimeElapsedColumn()) as progress:
    epoch_tqdm = progress.add_task(description="epoch progress", total=num_epochs)
    for epoch in range(num_epochs):
        train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, progress, lamb=0.6)

        progress.advance(epoch_tqdm, advance=1)
        if epoch == 10:
          torch.save(feature_extractor.state_dict(), f'extractor_model_early.bin')
          torch.save(label_predictor.state_dict(), f'predictor_model_early.bin')
        elif epoch == 100:
          torch.save(feature_extractor.state_dict(), f'extractor_model_mid.bin')
          torch.save(label_predictor.state_dict(), f'predictor_model_mid.bin')

        torch.save(feature_extractor.state_dict(), f'extractor_model.bin')
        torch.save(label_predictor.state_dict(), f'predictor_model.bin')
        print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

2.3 Strong Baseline

助教建议了几篇论文来提升性能并通过strong baseline。其中,我发现以下这篇论文特别有趣:《Minimum Class Confusion for Versatile Domain Adaptation》(Jin, Ying, et al.)(链接)。

他们“提出了一种新颖的损失函数:Minimum Class Confusion(MCC)。它可以被描述为一种新颖且多功能的领域适配方法,无需显式进行领域对齐,且具有较快的收敛速度。此外,它还可以作为一种通用正则化器,与现有的领域适配方法正交且互补,从而进一步加速和改善这些已有的竞争性方法。”(Jin, Ying, et al.,p. 3)

The schematic of the Minimum Class Confusion (MCC) loss function (source: https://arxiv.org/abs/1912.03699)

MCC 的计算过程如下:

给定以下变量:

  • \mathbf{f}_t:网络输出的目标领域数据的logits(即网络分类器的输出)。

  • T :一个温度参数,用于缩放logits,使其更加平滑并增大类别分布之间的差异。

  • \mathbf{p}_t:目标领域经温度平滑后的预测结果,表示通过softmax得到的概率分布。

  • H(\cdot):熵函数,用于衡量每个样本的预测不确定性。

MCC步骤1:目标领域logits的温度缩放:

目标领域的logits ​ \mathbf{f}_t 通过温度进行缩放,以平滑分类概率:

\\ \mathbf{f}_t' = \frac{\mathbf{f}_t}{T} \\

其中, T > 1 用于拉伸预测的概率分布,防止模型过于自信。

MCC步骤2:计算Softmax输出:

将经过温度缩放的logits通过softmax函数得到目标领域预测的概率分布 \mathbf{p}_t ​:

\mathbf{p}_t = \text{Softmax}(\mathbf{f}_t') \\

此处, \mathbf{p}_t ​是一个 N \times C 的矩阵,其中 N 是目标领域样本的数量,C 是分类的类别数。

MCC步骤3:计算样本熵权重:

每个样本的熵 H(\mathbf{p}_t) 使用以下公式计算:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值