跨模态转移旨在利用大型预训练模型来完成可能不属于预训练数据模态的任务。现有的研究在将经典微调扩展到跨模态场景方面取得了一定的成功,但仍然缺乏对模态差距对转移的影响的理解。在这项工作中,进行了一系列关于转移过程中源表示质量的实验,揭示了更大的模态差距与较少知识重用之间的联系,这意味着转移效果不佳。然后,使用条件分布 P ( Y ∣ X ) P(Y|X) P(Y∣X) 将这种差距形式化为不同模态之间的知识不对齐。针对这个问题,论文提出了“模态知识对齐”(
MoNA
)方法,这是一种元学习方法,旨在学习目标数据变换,以减少转移前的模态知识差异。实验证明,论文的方法实现了更好地重用跨模态转移中的源模态知识,从而对现有微调方法的改进。来源:晓飞的算法工程笔记 公众号,转载请注明出处
论文: Learning Modality Knowledge Alignment for Cross-Modality Transfer
Introduction
将来自过去经验的知识转移至新任务是人类智能的基本能力。在机器学习社区中,不断追求这种获取和重用知识的能力,旨在构建能够更准确预测并更高效学习数据的人工智能系统。如今,由于在大量数据上进行训练的大型基础模型已广泛可用,因此使用这样的预训练模型作为新任务的强大特征提取器已成为迁移学习的常见做法。自然地,预训练模型和下游任务来自同一模态,例如,在ImageNet
上预训练的视觉Transformer
模型和CIFAR-100
分类任务。然而,最近的研究尝试扩展到跨模态转移,例如使用视觉Transformer
进行音频分类,并对表格数据进行微调语言模型。
这种跨模态转移的动机易于理解,特别是当目标模态数据稀缺时。科学任务,如心电图分类和蛋白质距离预测,在收集大量训练数据方面存在困难,并且进一步需要人类专家进行昂贵的注释成本。在这种情况下,利用其他数据更容易收集的模态(如视觉和语言)的预训练模型,来帮助目标模式任务是可取的。然而,由于两个挑战,跨模态转移并不像模态内转移那样直接:1
)跨模态的输入空间和标签空间不同,2
)解决不同模态任务所需的知识也可能不同。
先前的研究通过设计模态特定的嵌入器和预测器来应对第一个挑战,以便从输入到输出与预训练模型进行接口。然而,第二个挑战尚未得到很好解决。一些方法认为,大型预训练模型可以作为通用编码器,并在微调过程中冻结预训练模型。其他方法则同时微调预训练模型和模态特定组件。这两种方法实证表明,预训练模型可以迁移到其他模态。然而,源模态中哪些知识通过预训练模型进行了转移,以及这些知识如何有利于目标模态,仍然是一个未解决的核心问题。例如,ORCA
观察到,在某些目标模态任务上从头开始训练模型甚至比对预训练模型进行普通微调更好。这表明如果不恰当地转移,预训练模型中包含的知识可能不会提高目标性能。
在这项工作中,论文深入探讨了跨模态转移的第二个挑战。首先进行实验,研究目标模态微调如何影响源模态数据的表示质量。论文观察到,在一些目标模态任务上微调预训练的Swin Transformer
可以帮助Swin
编码器提取更具有区分性的图像特征,而在其他模态上微调则会削弱这种能力。这一实证观察表明,在不同程度上,不同模态之间可能存在知识面,称之为模态语义知识,这些知识会影响跨模态转移的有效性。
为了明确模态之间差异的这一方面,将模态语义知识解释为条件概率分布 P ( Y ∣ X ) P(Y|X) P(Y∣X) 。根据目标模态的任务修改源模态的条件分布,使两者可以进行比较。因此,能够将模态知识差异形式化为源模态和目标模态的条件分布之间的差异。当目标条件分布与修改后的源条件分布相似时,称模态语义知识是对齐的,预训练模型学习的源区分功能可以被重用于目标模态。相反,模态语义知识相互矛盾,可能没有相互促进的作用,这解释了ORCA
中的观察。
论文的解释为理解先前跨模态转移作品提出的两阶段调优流程的有效性提供了新的视角:将第一阶段视为针对目标模态的隐式数据转换学习,从而使转换后数据的条件分布与源数据更加对齐。因此,这启示可以在微调之前直接学习一个适当的目标嵌入函数,有助于最小化知识不对齐。基于此,论文提出了一种新方法MoNA
,通过两阶段训练改善跨模态转移。在第一阶段,MoNA
利用元学习来学习一个最优的目标嵌入器,在全面微调时作为初始化之一,结合预训练权重,实现在全面微调过程中最大程度地重用源模态知识。在第二阶段,利用学习到的目标嵌入器作为起点,沿用传统微调方法,更新所有参数以适应目标任务,同时最大程度地利用源知识。
论文在两个跨模态转移基准数据集NAS-Bench-360
和PDEBench
上进行了大量实验,以验证假设和提出的方法的有效性。这两个基准数据集都集中在与科学问题相关的模态上,其中训练数据的稀缺性尤为严重。对MoNA
与先前方法进行了比较,实验结果表明该方法表现出色。
Problem Formulation and Analysis
Introduction to basic notations and architecture
考虑源模态 M s \mathcal{M}^s Ms 和目标模态 M t \mathcal{M}^t Mt 之间的知识转移。源模态中的数据(如视觉或语言数据)更容易获取且成本更低,同时大型预训练模型也是公开可用的。相反,目标模态数据不足以预训练自己的大型模型。这两个模态在输入空间和标签空间上均存在差异,即 X s ≠ X t \mathcal{X}^s\neq\mathcal{X}^t Xs=Xt , Y s ≠ Y t \mathcal{Y}^s\neq\mathcal{Y}^t Ys=Yt 。跨模态转移旨在利用源预训练模型(由参数 θ S \boldsymbol{\theta}^{\mathcal{S}} θS 参数化)来帮助处理目标任务,该目标任务仅具有一小组带标签数据 { x i t , y i t } i = 1 n t \{\boldsymbol{x}_i^{t}, y_i^{t}\}_{i=1}^{n_t} { xit,yit}i=1nt 。
根据先前的研究,模型结构 g θ g_{\boldsymbol{\theta}} gθ 包括一个嵌入器 e ( ⋅ ; θ e ) e(\cdot;\boldsymbol{\theta}_e) e(⋅;θe) ,一个Transformer
编码器 f ( ⋅ ; θ f ) f(\cdot;\boldsymbol{\theta}_f) f(⋅;θf) 和一个预测器 h ( ⋅ ; θ h ) h(\cdot;\boldsymbol{\theta}_h) h(⋅;θh) ,整个模型的参数表示为 θ = { θ e , θ f , θ h } \boldsymbol{\theta} = \{\boldsymbol{\theta}_e, \boldsymbol{\theta}_f, \boldsymbol{\theta}_h\} θ={
θe,θf,θh} 。特别地,预训练的Transformer
具有自己的嵌入器和预测器,因此将源模型的预训练权重表示为 θ 0 S = { θ e 0 S , θ f 0 S , θ h 0 S } \boldsymbol{\theta}^{\mathcal{S}}_0 = \{\boldsymbol{\theta}_{e_0}^\mathcal{S}, \boldsymbol{\theta}_{f_0}^\mathcal{S}, \boldsymbol{\theta}_{h_0}^\mathcal{S}\} θ0S={
θe0S,θf0S,θh0S} 。
嵌入器将输入数据映射到共享的输入嵌入空间 X ^ \hat{\mathcal{X}} X^ ,编码器从嵌入的输入中提取特征。预测器是一个线性层,将编码器的输出映射到标签空间上。对于目标模型 g θ T : θ T = { θ e T , θ f T , θ h T } g_{\boldsymbol{\theta}^\mathcal{T}}: \boldsymbol{\theta}^\mathcal{T} = \{\boldsymbol{\theta}^\mathcal{T}_e, \boldsymbol{\theta}^\mathcal{T}_f, \boldsymbol{\theta}^\mathcal{T}_h\} gθT:θT={ θeT,θfT,θhT} ,嵌入器和预测器均经过特定重新设计以适应目标任务的输入和标签空间,同时使用 θ f 0 S \boldsymbol{\theta}_{f_0}^\mathcal{S} θf0S 来初始化编码器权重 θ f T \boldsymbol{\theta}^\mathcal{T}_f θfT 。
这种架构的灵活性使得能够在目标任务上进行端到端的训练,简单地通过在给定训练数据集上最小化特定任务损失来微调目标模型的所有参数:
KaTeX parse error: Undefined control sequence: \label at position 196: … y_i^t\right), \̲l̲a̲b̲e̲l̲{eq:vanilla} \e…
这里的 ℓ \ell ℓ 是任务损失函数,例如交叉熵。
通过这种方式直接从目标监督中学习,鼓励模型学习有助于区分目标数据的知识。由于预训练模型已经包含源领域的辨别知识,因此跨模态转移自然期望源领域和目标领域的知识在某些方面相似,以便源领域的知识可以被重用来促进目标学习。接下来,1
)进行实验表明这种相似性取决于模态,2
)提供模态知识的解释并形式化知识差异。
特定于模态的嵌入器和预测器的实现完全按照ORCA
中的设计进行。
特定于模态的嵌入器的结构取决于任务是2D
还是1D
。
-
对于
2D
任务,嵌入器由线性投影层和LayerNorm
操作组成。对于任何大小为 C × H × W C\times H\times W