【论文阅读】Cross Attention Network for Few-shot Classification

在这里插入图片描述
用于小样本分类的交叉注意力网络
引用:Hou, Ruibing, et al. “Cross attention network for few-shot classification.” Advances in neural information processing systems 32 (2019).
论文地址:下载地址
论文代码:https://github.com/blue-blue272/fewshot-CAN

Abstract

  少样本分类旨在仅给出少量标记样本的情况下,识别来自未见过类别的未标记样本。未见类别和数据不足的问题使得少样本分类极具挑战性。许多现有的方法分别从标记和未标记样本中提取特征,导致特征的区分性不足。在本工作中,我们提出了一种新颖的交叉注意力网络,以解决少样本分类中的挑战性问题。首先,提出了交叉注意力模块来应对未见类别的问题。该模块为每对类别特征和查询样本特征生成交叉注意力图,以突出目标对象区域,从而使提取的特征更具区分性。其次,我们提出了一种推理算法,以缓解数据不足的问题。该算法通过迭代地利用未标记的查询集来增强支持集,从而使类别特征更加具有代表性。在两个基准数据集上的大量实验表明,我们的方法是简单、高效且计算成本低的框架,并且优于当前最先进的方法。

1 Introduction

  少样本分类旨在将未标记样本(查询集)分类到未见过的类别中,前提是只提供了极少量标记样本(支持集)。与传统分类相比,少样本分类面临两个主要挑战:一是未见过的类别,即训练和测试类别之间不重叠;二是数据不足的问题,即测试的未见类别只有极少量的标记样本。解决少样本分类问题需要模型在训练阶段的已知类别中进行训练,并能够在仅有少量标记样本的情况下对未见类别进行良好的泛化。

  一种直接的方法是使用来自未见类别的少量标记样本对预训练模型进行微调,但这可能会导致严重的过拟合。正则化和数据增强能够缓解但不能完全解决过拟合问题。最近,元学习范式 1 2 3 被广泛用于少样本学习。在元学习中,可迁移的元知识,例如一种优化策略 4 5,一种良好的初始条件 6 7 8,或一个度量空间 9 10 11,是从一组训练任务中提取的,并且能够泛化到新的测试任务。训练阶段的任务通常模拟测试阶段的设置,以缩小训练和测试设置之间的差距,并增强模型的泛化能力。

  尽管已有一些很有前景的方法,但很少有方法充分关注提取特征的可区分性。它们通常独立地从支持类和未标记的查询样本中提取特征,导致这些特征的区分性不足。一方面,支持/查询集中的测试图像来自未见过的类别,因此它们的特征难以关注目标对象。具体来说,对于包含多个对象的测试图像,所提取的特征可能会关注训练集中具有大量标记样本的已见类别中的对象,而忽略来自未见类别的目标对象。如图1(c)和(d)所示,对于来自测试类别“窗帘”的两张图像,提取的特征仅捕获了与训练类别相关的对象信息,如图1(a)和(b)中的“人”或“椅子”。另一方面,数据不足的问题使得每个测试类别的特征不能代表真实的类别分布,因为它仅从少量标记的支持样本中获得。总而言之,独立的特征表示在少样本分类中可能会失效。
在这里插入图片描述
图1所示。现有方法9和所提出方法的训练和测试图像的类激活图12的一个例子可以。颜色越暖值越高。

  在本工作中,我们提出了一种新颖的交叉注意力网络(Cross Attention Network, CAN)来增强少样本分类中特征的可区分性。首先,引入了交叉注意力模块(Cross Attention Module, CAM)来解决未见类别的问题。交叉注意力的思想来源于人类的少样本学习行为。给定少量标记样本来识别来自未见类别的样本时,人类倾向于首先定位标记和未标记样本对中最相关的区域。类似地,给定类别特征图和查询样本特征图,CAM 为每个特征生成一个交叉注意力图,以突出目标对象。相关性估计和元融合被用于实现这一目标。这样一来,测试样本中的目标对象可以得到关注,并且通过交叉注意力图加权的特征更具区分性。如图1(e)所示,带有 CAM 的提取特征能够大致定位目标对象“窗帘”的区域。

  其次,我们提出了一种推理算法,利用整个未标记的查询集来缓解数据不足的问题。所提出的算法通过迭代地预测查询样本的标签,并选择伪标记的查询样本来增强支持集。每个类别中更多的支持样本能够使得获得的类别特征更加具有代表性,从而缓解数据不足的问题。

  在多个基准数据集上进行了实验,将所提出的 CAN 与现有的少样本元学习方法进行了比较。我们的方法在所有数据集上都达到了新的最先进的结果,证明了 CAN 的有效性。

2 Related Work

2.1 少样本分类

  根据整个未标记查询集的可用性,少样本分类可分为两类:归纳式少样本分类和传导式少样本分类。在本工作中,我们主要探讨基于元学习的少样本方法。

2.2 归纳式少样本学习

  归纳式少样本学习在近年来得到了广泛的研究。一种有前途的方法是元学习 1 2 3 范式。通常,它从一组任务中训练一个元学习器,以提取可迁移到新任务的元知识,使其能够应对稀缺数据的情况。用于少样本分类的元学习方法大致可以分为三类:基于优化的方法将元学习器设计为优化器,用于学习更新模型参数 5 4 13。此外,研究 6 14 15 学习了一个良好的初始条件,使学习器能够在少数优化步骤内快速适应新任务。基于参数生成的方法 16 17 18 19 通常将元学习器设计为一个参数预测网络。基于度量学习的方法 10 9 11 20 21 学习了一个通用特征空间,在该空间中基于距离度量可区分类别。
  我们提出的框架属于基于度量学习的方法。与现有的度量学习方法独立提取支持样本和查询样本特征不同,我们的方法利用支持和查询特征之间的语义相关性来突出目标对象。尽管基于参数生成的方法也考虑了支持和查询样本之间的关系,但这些方法需要额外的复杂参数预测网络。我们的方法在较少的计算负担下,比这些方法取得了更大的提升。

2.3 传导式算法

传导式少样本分类首先由 22 引入,该方法在支持集和整个查询集上构建了一个图,并在图中传播标签。然而,该方法需要一个特定的网络结构,通用性较差。受半监督学习中的自训练策略启发 23 24 25 26,我们提出了一种更简单且通用性更强的传导式少样本算法,该算法通过明确地使用未标记的查询样本来增强标记支持集,以获得更具代表性的类别特征。此外,所提出的传导式算法可以直接应用于现有的模型,例如 Prototypical Network 9、Matching Network 10 和 Relation Network 11

2.4 注意力模型

  注意力机制旨在突出重要的局部区域,以提取更具区分性的特征。在计算机视觉应用中取得了巨大的成功,如图像分类 27 28 29、图像描述生成 30 31 32 和视觉问答 33 34 35。在图像分类中,SENet 27 提出了一种通道注意力模块,以增强网络的表示能力。Woo 等人 28 29 进一步将通道和空间注意力模块集成到一个模块中。然而,这些模块对于少样本分类并不有效。我们认为它们仅基于训练类别的先验来定位测试图像中的重要区域,无法泛化到来自未见类别的测试图像。例如,如图1所示,由于“窗帘”不属于训练类别,上述模块将关注于已见类别的前景对象,如“人”或“椅子”,而不是“窗帘”。
  对于少样本图像分类,在本文中,我们设计了一个元学习器来计算支持(或类别)和查询特征图之间的交叉注意力,以帮助定位目标对象的关键区域,并增强特征的区分性。

3 Cross Attention Module

3.1 问题定义

  少样本分类通常涉及训练集、支持集和查询集。训练集包含大量类别和标记样本。少样本的支持集和未标记样本的查询集共享相同的标签空间,并且与训练集的标签空间不重叠。少样本分类的目标是在给定训练集和支持集的情况下对未标记的查询样本进行分类。如果支持集包含 C C C 个类别且每个类别包含 K K K 个标记样本,那么该目标少样本问题称为 C C C-way K K K-shot。

  参考文献 10 9 36 37 38 39 6,我们采用了“episode”训练机制,该机制被证明是少样本学习中的一种有效方法。训练过程中使用的每个“episode”模拟了测试时的设置。每个“episode”通过随机采样 C C C 个类别以及每个类别 K K K 个标记样本组成支持集 S = { ( x a s , y a s ) } a = 1 n s S = \{(x_{a}^s, y_{a}^s)\}_{a=1}^{n_s} S={(xas,yas)}a=1ns(其中 n s = C × K n_s = C \times K ns=C×K),并从这 C C C 个类别中抽取一部分剩余样本组成查询集 Q = { ( x b q , y b q ) } b = 1 n q Q = \{(x_{b}^q, y_{b}^q)\}_{b=1}^{n_q} Q={(xbq,ybq)}b=1nq。我们将第 k k k 类的支持子集记作 S k S_k Sk。如何表示每个支持类别 S k S_k Sk 和查询样本 x b q x_{b}^q xbq,以及测量它们之间的相似性,是少样本分类的关键问题。

3.2 CAM 概述

  在本工作中,我们借助度量学习来获得支持类和查询样本对的合适特征表示。与现有独立提取类别和查询特征的方法不同,我们提出了交叉注意力模块(Cross Attention Module, CAM),它可以对类别特征和查询特征之间的语义相关性进行建模,从而突出目标对象,并有助于后续的匹配。

  CAM 如图 2 所示。类别特征图 P k ∈ R c × h × w P_k \in \mathbb{R}^{c \times h \times w} PkRc×h×w 从支持集 S k S_k Sk 中提取( k ∈ { 1 , 2 , … , C } k \in \{1, 2, \ldots, C\} k{ 1,2,,C}),查询特征图 Q b ∈ R c × h × w Q_b \in \mathbb{R}^{c \times h \times w} QbRc×h×w 从查询样本 x b q x_b^q xbq 中提取( b ∈ { 1 , 2 , … , n q } b \in \{1, 2, \ldots, n_q\} b{ 1,2,,nq}),其中 c c c h h h w w w 分别表示特征图的通道数、高度和宽度。CAM 为 P k P_k Pk Q b Q_b Qb)生成交叉注意力图 A p A_p Ap A q A_q Aq),并用其加权特征图以获得更具区分性的特征表示 P ˉ k b \bar{P}_k^b Pˉkb Q ˉ b k \bar{Q}_b^k Qˉbk)。为了简化,我们省略上下标,将输入的类别和查询特征图分别记作 P P P Q Q Q,输出的类别和查询特征图分别记作 P ˉ \bar{P} Pˉ Q ˉ \bar{Q} Qˉ

在这里插入图片描述
图2. (a) 交叉注意力模块(CAM)。(b) CAM 中的融合层。在图中, R p R_p Rp ( R q R_q Rq) ∈ R m × m \in \mathbb{R}^{m \times m} Rm×m 被重塑为 R m × h × w \mathbb{R}^{m \times h \times w} Rm×h×w 以便于更好的可视化。如图所示,CAM 能够生成关注目标对象区域(图中的“被涂层的寻回犬”)的特征图。

3.3 相关性层

  如图 2 所示,我们首先设计了一个相关性层来计算 P P P Q Q Q 之间的相关性图,该图用于指导生成交叉注意力图。为此,我们首先将 P P P Q Q Q 重塑为 R c × m \mathbb{R}^{c \times m} Rc×m,即 P = [ p 1 , p 2 , … , p m ] P = [p_1, p_2, \ldots, p_m] P=[p1,p2,,pm] Q = [ q 1 , q 2 , … , q m ] Q = [q_1, q_2, \ldots, q_m] Q=[q1,q2,,qm],其中 m m m m = h × w m = h \times w m=h×w)是每个特征图的空间位置数。 p i p_i pi q i ∈ R c q_i \in \mathbb{R}^c qiRc 分别表示 P P P Q Q Q 中第 i i i 个空间位置的特征向量。相关性层使用余弦距离计算 { p i } i = 1 m \{p_i\}_{i=1}^m { pi}i=1m { q i } i = 1 m \{q_i\}_{i=1}^m { qi}i=1m 之间的语义相关性,以得到相关性图 R ∈ R m × m R \in \mathbb{R}^{m \times m} RR

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值