【Few-Shot Incremental Learning】Semantics-Driven Generative Replay in ACM MM 2022 个人理解

该博客围绕少样本类增量学习展开,介绍了Semantics - Driven Generative Replay方法。在初始训练集上,将其分段模拟少样本场景,引入特征提取器、语义映射模块、生成对抗网络和四项损失;在增量训练集上,调整损失,引入蒸馏损失避免遗忘,扩充原型实现增量学习。

一、简介

题目: Semantics-Driven Generative Replay for Few-Shot Class Incremental Learning
会议: ACM MM 2022
任务🎯: 给定一批样本充足的初始数据,之后新的数据/任务依次到来,新的数据/任务到来后旧的就不再可获取,并且新的数据相比初始数据可能是稀缺的(例如,初始类别每类可能有100个样本,但新的类别每类可能只有5个样本),要求模型在学习新数据/任务的同时不要忘记旧数据/任务。
Idea✨:
(1)使模型提前适应少样本增量场景,并引入四项损失进行优化: 首先,将初始样本充足的训练数据分割降采样为少样本的增量形式;其次,引入标签语义映射模块降低生成对抗模型对数据量的依赖;最后,引入四个损失在分割的几段数据上进行模型训练,通过 L G A N \mathcal{L}_{GAN} LGAN保证生成器生成的假样本与真样本相似,通过 L V 2 S \mathcal{L}_{V2S} LV2S保证生成的假样本与同类的语义原型接近,通过 L M C \mathcal{L}_{MC} LMC避免模式崩溃,通过 L C L S \mathcal{L}_{CLS} LCLS保证模型分类能力。
(2)通过知识蒸馏来避免灾难性遗忘,扩充原型并更新模型实现增量学习: 首先,对于旧类原型,结合均方误差和KL散度计算蒸馏损失 L D I S \mathcal{L}_{DIS} LDIS以避免灾难性遗忘;其次,扩充新类原型,重新计算在新类上的 L G A N \mathcal{L}_{GAN} LGAN L M C \mathcal{L}_{MC} LMC L C L S \mathcal{L}_{CLS} LCLS以更新模型。
Note⚠️:ALICE in ECCV 2022和SAVC in CVPR 2023把重点放在训练一个强大的具有泛化能力的特征提取器上不同,该工作将重点放在了如何在少样本场景中使用生成对抗网络上。
初始训练
如图,作者在初始训练阶段引入了四个损失以确保生成器在少样本场景下的学习能力。

二、详情

1. 在初始训练集上的学习过程

初始训练集样本充足,为使模型提前适应少样本增量场景,作者将初始训练集分为了多段,每段是随机抽取的 K K K个类别和 N N N个样本,形成 K K K-way N N N-shot,不同段之间类别不重叠。第 q q q个数据段记作 S ( q ) \mathcal{S}(q) S(q)

1.1 特征提取器

因为作者希望将初始训练集分段以提前模拟少样本增量场景,所以直接使用的是预训练过的特征提取器。具体来说,就是在ImageNet上预训练过的ResNet-18,记作 F \mathcal{F} F。在分段的初始训练集上,特征提取器会被微调。

通过特征提取器可以提取特征并计算各类特征的均值,计算过程如下:

其中, a i = F ( x i ) a_i=\mathcal{F}(x_i) ai=F(xi)为特征; y i y_i yi为对应的标签; s k q s_k^q skq为第 q q q段数据第 k k k个类的原型, { s 1 q , s 2 q , ⋯   , s K q } \{s_1^q,s_2^q,\cdots,s_K^q\} { s1q,s2q,,sKq}会被保存下来。

1.2 语义映射模块

少样本不足以支撑生成对抗网络(Generative Adversarial Network, GAN)的学习,于是作者引入语义映射模块辅助GAN的学习。语义映射模块是一个图卷积网络(Graph Convolutional Network, GCN),记作 P \mathcal{P} P

首先,我们需要了解图神经网络(Graph Neural Network, GNN),假设有如下关系:

如图,展示了GNN的两个重要输入,不同类别的节点特征 a 1 , ⋯   , a 5 a_1,\cdots,a_5 a1,,a5,不同类别节点间的邻接关系矩阵(有关则为1,无关则为0)。

我们可以简单的将GNN视为一个特征更新模型,GNN认为某一类别的节点特征是与邻接类的节点特征相关的。对于 a 1 a_1 a1来说,一次更新过程如下:

a 1 ′ = σ ( W ( a 1 + β 1 × a 4 + β 2 × a 5 ) ) ) a^{\prime}_1=\sigma(W(a_1+\beta_1\times a_4+\beta_2\times a_5))) a1=σ(W(a1+β1×a4+β2×a5)))

其中, β 1 \beta_1 β1 β 2 \beta_2 β2可以是人工设定的权重, W W W是需要训练的参数, σ \sigma σ是激活函数(ReLU)。可以看到,经过一次更新每个类的节点特征都包含了自身和邻接类的节点特征,例如 a 1 ′ a^{\prime}_1 a1包含了 { a 1 , a 4 , a 5 } \{a_1,a_4,a_5\} { a1,a4,a5} a 4 ′ a^{\prime}_4 a4包含了 { a 1 , a 2 , a 3 , a 4 , a 5 } \{a_1,a_2,a_3,a_4,a_5\} { a1,a2,a3,a4,a5

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fulin_Gao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值