【论文解读】前向可兼容的少样本增量学习


Forward Compatible Few-Shot Class-Incremental Learning - CVPR2022原文链接

本文关注的问题是少样本类增量学习(Few Shot Class Incremetal Learning, FSCIL)当前主流的方法基本上都是在学习新类的同时不忘记旧类,侧重于对于旧类的遗忘问题,本文中称这种为后向兼容性。反之,本文在此基础上提出可以前瞻性的学习新类以应对未来的模型更新,称为前向兼容性。主要是通过虚拟原型压缩已知类的嵌入,同时预测可能出现的新类并为模型更新做准备。

1.背景

近些年,类增量学习完成的工作不在是离线全部训练好模型再进行推理,而是通过对着类别的增加增量的完成学习过程,而当前学习阶段只能使用当前的数据,因此对于当前模型的学习往往会忘记之前学习到的知识。目前大多数工作关注的都是在学习新类的时候尽量不遗忘旧知识,注重的是对于旧知识的兼容性的平衡,本文在此基础上提出前向兼容性可以提前预知可能出现的新类。

2.相关工作

少样本学习

少样本学习即通过较少的样本进行学习以优化模型,目前主要是三种方式:基于权重重要性(防止重要权重发生改变)、基于知识蒸馏(knowledge distillation)、利用之前的实例重新训练(rehearsal)

少样本类增量学习

主要介绍了几个SOTA方法,TOPIC通过神经气体网络保存特征的拓扑结构从而防止遗忘,语义注意知识蒸馏将词嵌入作为辅助信息,FSLL增量阶段选择参数更新,CEC为了适应增量过程,在分类器间利用额外的图模型进行上下文信息传播

向后兼容学习

软件更新过程中考虑向后兼容的问题,即新版本兼容旧版本,而前向兼容预示着可以兼容未来的版本,大多数动作都是集中于向后兼容学习,没有考虑到前向兼容性

3.从旧类到新类

3.1 FSCIL

基础阶段

数据集中得到一部分类作为基类,分为训练集和测试集,对于训练集进行正常的训练即可,然后在测试集上进行测试。

增量阶段

剩下的类别分为几组,每组类别不相交,每个增量阶段采用 N-way M shot的实验方式,即每个增量阶段使用的是N类样本M个样本的方式。剩下类的训练集完成这样的训练过程后,在所有已知类测试集中进行测试。

3.2 FSCIL的后向兼容性训练

知识蒸馏

交叉熵损失和知识蒸馏损失:

首先是正常网络输入使用的交叉熵损失函数,另一个知识蒸馏损失针对的是之前类的蒸馏损失,相当于软标签与网络输出的交叉熵,知识蒸馏的关键点就是采用原有模型的输出作为软标签与当前网络输出计算交叉熵损失,具体可以了解知识蒸馏的相关概念。

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 - 知乎 (zhihu.com)

原型网络

由于知识蒸馏对于少样本输入容易造成过拟合,因此通过原型网络克服过拟合现象,通过固定的嵌入函数提取增量阶段每一类的平均嵌入:

这里相当于通过计算每个类的平均嵌入用来代表每个类的权重,即 w i = p i w_i=p_i wi=pi.
论文中提到以上两种处理方式只是考虑了后向兼容性,即增加新类时注重平衡旧类以防止遗忘问题,因此引出前向兼容性,网络能够前瞻性的预示新类的出现从而做好准备。

4. FSCIL的前向兼容训练

论文将前向兼容的描述单独用一个小节描述,两个重要创新点:1. 为了模型能够生长,优化后验概率为双峰分布,即除了真实标签(Ground Truth)外,还有另一个其他类;2. 使用实例混合的方法使得可以预见新类的可能性分布。两点目的是使得模型能够在增量训练过程中具有前瞻性的能力。

4.1 用虚拟原型进行预训练

分配虚拟原型

第一项是网络输出与真实标签的损失,第二项是虚拟标签与带有掩膜函数的损失,其中 f v ( x ) = [ W , P v ] T ϕ ( x ) f_v({\rm{x}})=[W,P_v]^{\textsf{T}}\phi(\textrm{x}) fv(x)=[W,Pv]Tϕ(x),其中 W W W指的是类别原型矩阵, P v P_v Pv是嵌入空间中的虚拟原型矩阵: P v = [ p 1 , ⋯   , p V ] ∈ R d × V P_v=[p_1,\cdots,p_V]\in \mathbb{R}^{d\times V} Pv=[p1,,pV]Rd×V V V V是虚拟类数量, y y y是真实标签, y ^ \hat{y} y^是虚拟标签: y ^ = arg max ⁡ v p v T ϕ ( x ) + ∣ Y 0 ∣ \hat{y}=\argmax_v\bm{p}_v^{\textsf{T}}\phi(\textrm{x})+|Y_0| y^=vargmaxpvTϕ(x)+Y0

这里主要提出虚拟原型的目的是使得输出分布为双峰分布,第一项使得实例趋向于真实簇,第二项使得趋向于最近的虚类,如下入所示。

预测虚拟实例

在这里插入图片描述

其中 z \bm{z} z虚拟实例,通过mixup混合实例得到,原文: z = g [ λ h ( x i ) + ( 1 − λ ) h ( x j ) ] {\bm{z}}=g[\lambda h(\textrm{x}_i)+(1-\lambda)h(\textrm{x}_j)] z=g[λh(xi)+(1λ)h(xj)],对于 L f ( z ) \mathcal{L}_f(\bm{z}) Lf(z)的定义和虚拟原型损失定义类似,其中第一项使得虚拟示例趋向于虚拟原型,第二项使得虚拟实例趋向于最相近的已知类。

4.2 使用虚拟原型进行增量推理

通过虚拟原型以及虚拟类别构建全概率公式预测推理当前的类别输出。

5. 实验

在这里插入图片描述

从实验效果中可以得到完成很好的表现

论文中在CIFAR100、CUB200-2011、miniImageNet完成实验验证。

训练细节:CIFAR100使用ResNet20结构,其他的使用ResNet18结构,所有的对比方法采用的一样的主干网络,使用带动量的SGD优化算法并使用学习率衰减策略等。

评价指标:每个增量阶段的平均准确率,本文采用的是Top1,同时还有遗忘现象的评估,目前主流的类增量学习的评价指标主要是这两个方面的衡量。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值