【Few-Shot Incremental Learning】Semantics-Driven Generative Replay in ACM MM 2022 个人理解

该博客围绕少样本类增量学习展开,介绍了Semantics - Driven Generative Replay方法。在初始训练集上,将其分段模拟少样本场景,引入特征提取器、语义映射模块、生成对抗网络和四项损失;在增量训练集上,调整损失,引入蒸馏损失避免遗忘,扩充原型实现增量学习。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、简介

题目: Semantics-Driven Generative Replay for Few-Shot Class Incremental Learning
会议: ACM MM 2022
任务🎯: 给定一批样本充足的初始数据,之后新的数据/任务依次到来,新的数据/任务到来后旧的就不再可获取,并且新的数据相比初始数据可能是稀缺的(例如,初始类别每类可能有100个样本,但新的类别每类可能只有5个样本),要求模型在学习新数据/任务的同时不要忘记旧数据/任务。
Idea✨:
(1)使模型提前适应少样本增量场景,并引入四项损失进行优化: 首先,将初始样本充足的训练数据分割降采样为少样本的增量形式;其次,引入标签语义映射模块降低生成对抗模型对数据量的依赖;最后,引入四个损失在分割的几段数据上进行模型训练,通过 L G A N \mathcal{L}_{GAN} LGAN保证生成器生成的假样本与真样本相似,通过 L V 2 S \mathcal{L}_{V2S} LV2S保证生成的假样本与同类的语义原型接近,通过 L M C \mathcal{L}_{MC} LMC避免模式崩溃,通过 L C L S \mathcal{L}_{CLS} LCLS保证模型分类能力。
(2)通过知识蒸馏来避免灾难性遗忘,扩充原型并更新模型实现增量学习: 首先,对于旧类原型,结合均方误差和KL散度计算蒸馏损失 L D I S \mathcal{L}_{DIS} LDIS以避免灾难性遗忘;其次,扩充新类原型,重新计算在新类上的 L G A N \mathcal{L}_{GAN} LGAN L M C \mathcal{L}_{MC} LMC L C L S \mathcal{L}_{CLS} LCLS以更新模型。
Note⚠️:ALICE in ECCV 2022和SAVC in CVPR 2023把重点放在训练一个强大的具有泛化能力的特征提取器上不同,该工作将重点放在了如何在少样本场景中使用生成对抗网络上。
初始训练
如图,作者在初始训练阶段引入了四个损失以确保生成器在少样本场景下的学习能力。

二、详情

1. 在初始训练集上的学习过程

初始训练集样本充足,为使模型提前适应少样本增量场景,作者将初始训练集分为了多段,每段是随机抽取的 K K K个类别和 N N N个样本,形成 K K K-way N N N-shot,不同段之间类别不重叠。第 q q q个数据段记作 S ( q ) \mathcal{S}(q) S(q)

1.1 特征提取器

因为作者希望将初始训练集分段以提前模拟少样本增量场景,所以直接使用的是预训练过的特征提取器。具体来说,就是在ImageNet上预训练过的ResNet-18,记作 F \mathcal{F} F。在分段的初始训练集上,特征提取器会被微调。

通过特征提取器可以提取特征并计算各类特征的均值,计算过程如下:

其中, a i = F ( x i ) a_i=\mathcal{F}(x_i) ai=F(xi)为特征; y i y_i yi为对应的标签; s k q s_k^q skq为第 q q q段数据第 k k k个类的原型, { s 1 q , s 2 q , ⋯   , s K q } \{s_1^q,s_2^q,\cdots,s_K^q\} {s1q,s2q,,sKq}会被保存下来。

1.2 语义映射模块

少样本不足以支撑生成对抗网络(Generative Adversarial Network, GAN)的学习,于是作者引入语义映射模块辅助GAN的学习。语义映射模块是一个图卷积网络(Graph Convolutional Network, GCN),记作 P \mathcal{P} P

首先,我们需要了解图神经网络(Graph Neural Network, GNN),假设有如下关系:

如图,展示了GNN的两个重要输入,不同类别的节点特征 a 1 , ⋯   , a 5 a_1,\cdots,a_5 a1,,a5,不同类别节点间的邻接关系矩阵(有关则为1,无关则为0)。

我们可以简单的将GNN视为一个特征更新模型,GNN认为某一类别的节点特征是与邻接类的节点特征相关的。对于 a 1 a_1 a1来说,一次更新过程如下:

a 1 ′ = σ ( W ( a 1 + β 1 × a 4 + β 2 × a 5 ) ) ) a^{\prime}_1=\sigma(W(a_1+\beta_1\times a_4+\beta_2\times a_5))) a1=σ(W(a1+β1×a4+β2×a5)))

其中, β 1 \beta_1 β1 β 2 \beta_2 β2可以是人工设定的权重, W W W是需要训练的参数, σ \sigma σ是激活函数(ReLU)。可以看到,经过一次更新每个类的节点特征都包含了自身和邻接类的节点特征,例如 a 1 ′ a^{\prime}_1 a1包含了 { a 1 , a 4 , a 5 } \{a_1,a_4,a_5\} {a1,a4,a5} a 4 ′ a^{\prime}_4 a4包含了 { a 1 , a 2 , a 3 , a 4 , a 5 } \{a_1,a_2,a_3,a_4,a_5\} {a1,a2,a3,a4,a5}。如果再进行一次更新,则 a 1 ′ ′ a^{\prime\prime}_1 a1′′会包含 { a 1 ′ , a 4 ′ , a 5 ′ } \{a^{\prime}_1,a^{\prime}_4,a^{\prime}_5\} {a1,a4,a5},也就包含了所有类的原始特征。

但是,这样做有一个问题,那就是 a 3 a_3 a3虽然与 a 4 a_4 a4相关,但明显不具备 a 4 a_4 a4那样庞大的关系网,如果不设置合理的 β 1 \beta_1 β1 β 2 \beta_2 β2,GNN的更新会使 a 3 a_3 a3 a 4 a_4 a4差异越来越小,这是我们不希望看到的。

GCN解决了这一问题,其更新公式如下:

H ′ = σ ( W ( D ˜ − 1 2 A ˜ D ˜ − 1 2 H ) ) H^{\prime}=\sigma\left(W\left(\~D^{-\frac{1}{2}}\~A\~D^{-\frac{1}{2}}H\right)\right) H=σ(W(D˜21A˜D˜21H))

其中, H H H H ′ H^{\prime} H是更新前后所有类的节点特征矩阵; A ˜ \~A A˜是各类间的邻接关系矩阵; D ˜ \~D D˜矩阵仅对角线有值(值为 A ˜ \~A A˜的行和,即 D ˜ i i = ∑ j A ˜ i j \~D_{ii}=\sum_j\~A_{ij} D˜ii=jA˜ij),其余为0; W W W是需要训练的参数, σ \sigma σ是激活函数(ReLU)。对于 a 1 a_1 a1,其计算过程如下:

a 1 ′ = σ ( W ( ∑ j 1 D ˜ i i D ˜ j j A ˜ i j a j ) ) = σ ( W ( 1 3 × 3 a 1 + 1 3 × 5 a 4 + 1 3 × 3 a 5 ) ) \begin{aligned} a^{\prime}_1=&\sigma\left(W\left(\sum_j\frac{1}{\sqrt{\~D_{ii}\~D_{jj}}}\~A_{ij}a_j\right)\right)\\ =&\sigma\left(W\left(\frac{1}{\sqrt{3\times 3}}a_1+\frac{1}{\sqrt{3\times 5}}a_4+\frac{1}{\sqrt{3\times 3}}a_5\right)\right) \end{aligned} a1==σ W jD˜iiD˜jj 1A˜ijaj σ(W(3×3 1a1+3×5 1a4+3×3 1a5))

可以看到, 1 / D ˜ i i D ˜ j j 1/\sqrt{\~D_{ii}\~D_{jj}} 1/D˜iiD˜jj 是GCN指定的权重,该权重会根据当前节点的关系网调整强度,关系越多,分母越大,权重越小。

综上,语义映射模块 P \mathcal{P} P的输入有两个:节点特征和邻接关系矩阵。节点特征可以使用当前阶段的各类别的原型 { s 1 q , s 2 q , ⋯   , s K q } \{s_1^q,s_2^q,\cdots,s_K^q\} {s1q,s2q,,sKq},邻接关系矩阵作者则是采用阈值化词嵌入相似度的方式获取。词嵌入就是将类别标签转换为一个向量,转换模型是在某个词库上预训练好的Word2Vec模型。这样,每个类标签都变成了一个向量,向量两两之间可以计算余弦相似度,对相似度设定一个阈值,超过阈值的为1,为超过的为0,则可以形成邻接关系矩阵。

通过GCD,输入上述节点特征和邻接关系矩阵,即可按照更新公式更新各类别的节点特征。之后,令GAN生成的特征向同类别的节点特征聚集并远离其他节点特征即可达到辅助GAN学习的目的。

1.3 生成对抗网络

作者使用GAN的目标是引入蒸馏损失保证当前GAN模型与上一阶段的GAN模型所生成旧类假样本的一致性,从而避免灾难性遗忘。作者使用的是条件GAN(Conditional GAN, CGAN)的训练策略,但不完全相同。

首先,我们需要了解经典的GAN,其结构图如下:

如图,GAN主要包括两个关键结构:生成器 G \mathcal{G} G和判别器 D \mathcal{D} D z z z为噪声向量。

GAN通过生成器与判别器的对抗实现生成接近真实样本的假样本的目标。一方面GAN希望输入 z z z后生成器生成的样本 G ( z ) \mathcal{G}(z) G(z)与真实样本 x x x接近,最好能够骗过判别器;另一方面GAN希望判别器具备辨别真假的能力,最好不要被判别器欺骗。因此,GAN的优化策略如下:

其中, p d a t a ( x ) p_{data}(x) pdata(x) p z ( z ) p_z(z) pz(z)分别是输入样本和噪声所服从的分布; G ( z ) \mathcal{G}(z) G(z)为生成器生成的样本, D ( ∗ ) \mathcal{D}(*) D()是指判别器作出的判别(真或假)。我们将 V ( D , G ) V(\mathcal{D},\mathcal{G}) V(D,G)写成下面的样子将更容易理解:

对于判别器来说, V ( D , G ) V(\mathcal{D},\mathcal{G}) V(D,G)中前后两项都与 D \mathcal{D} D有关,并且判别器希望判别 x i x_i xi为真,判别 G ( z i ) \mathcal{G}(z_i) G(zi)为假。所以最理想的情况是:

此时, L o s s D Loss_\mathcal{D} LossD取最大值0,其余情况均小于0,所以优化策略中有 max ⁡ D \max\limits_\mathcal{D} Dmax

对于生成器来说, V ( D , G ) V(\mathcal{D},\mathcal{G}) V(D,G)中仅后面一项与 G \mathcal{G} G有关,它希望 D \mathcal{D} D G ( z i ) \mathcal{G}(z_i) G(zi)判别为真。所以最理想的情况是:

此时, L o s s G Loss_\mathcal{G} LossG取最小值,其余情况均大于 − ∞ -\infty ,所以优化策略有 min ⁡ G \min\limits_\mathcal{G} Gmin

但是经典GAN有一个问题,因为没有标签的参与,输入一个噪声即得到一个假样本,但该假样本属于哪个类我们并不知道。

CGAN解决了这一问题,其结构图如下:

如图,CGAN将标签转化为词嵌入向量,然后分别与噪声向量、生成的假样本、真实样本拼接进行模型训练,其余部分均与GAN相同。这样,当我们希望生成某一类别的样本时,输入噪声和对应类别的词嵌入向量即可。

因此,CGAN的优化策略为如下形式:

与GAN的优化策略区别仅在于有无 ∣ y \color{red}|y y

不过,作者并没有输入词嵌入向量,而是用同样能够代表类别的原型 { s 1 q , s 2 q , ⋯   , s K q } \{s_1^q,s_2^q,\cdots,s_K^q\} {s1q,s2q,,sKq}

于是,得到如下优化策略:

其中, s q = { s 1 q , s 2 q , ⋯   , s K q } \textbf{s}_q=\{s_1^q,s_2^q,\cdots,s_K^q\} sq={s1q,s2q,,sKq} z ∼ N ( 0 , 1 ) z\sim\mathcal{N}(0,1) zN(0,1)。由于作者将GAN放在了特征提取器之后,因此这里的GAN生成的是特征而不是图像,因此CGAN中的 x x x在这成了 a a a

这样,输入一个类别原型和噪声,生成器即可生成一个对应类的特征,并且该特征与对应类的真实特征相似。如果前一阶段的生成器和当前的生成器输入相同的旧类原型和噪声后输出能够保持一致,则说明模型没有发生灾难性遗忘。

1.4 损失

1) L G A N \mathcal{L}_{GAN} LGAN
生成对抗损失,公式如下:

该损失能够保证生成器生成的特征质量,即与真实的特征相似,并且不同类别的生成特征是有区分的。
2) L V 2 S \mathcal{L}_{V2S} LV2S
语义映射损失,公式如下:

其中, w k w_k wk为第 k k k个类的标签词向量, P ( w k ) \mathcal{P}(w_k) P(wk)为通过语义映射模块输出的第 k k k类的节点特征; s ^ k j \hat{s}_k^j s^kj则是输入原型 s k q s^q_k skq和噪声 z j z_j zj后生成器生成的假特征;噪声 z 1 z_1 z1 z 2 z_2 z2是从 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1)中两次随机采样获得,进行两次采样是为了引入 L M C \mathcal{L}_{MC} LMC以避免模式崩溃。

该损失中前一项能够使生成的假特征 s ^ k j \hat{s}_k^j s^kj向同一类别的节点特征 P ( w k ) \mathcal{P}(w_k) P(wk)聚集,并远离其他节点特征 P ( w m ) , m ≠ k \mathcal{P}(w_m),m\neq k P(wm),m=k;后一项能够使同一类的 s ^ k j \hat{s}_k^j s^kj P ( w k ) \mathcal{P}(w_k) P(wk)进一步聚集。
3) L M C \mathcal{L}_{MC} LMC
反模式崩溃损失,公式如下:

其中, ϵ \epsilon ϵ用来避免分母为0; z 1 z_1 z1 z 2 z_2 z2是从 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1)中两次随机采样获得的噪声; s ^ 1 \hat{s}^1 s^1 s ^ 2 \hat{s}^2 s^2是分别输入 z 1 z_1 z1 z 2 z_2 z2和同一个原型后生成器生成的两个假特征,它们与原型属于同一类别。

z 1 z_1 z1 z 2 z_2 z2彼此接近时,对应生成器生成的 s ^ 1 \hat{s}^1 s^1 s ^ 2 \hat{s}^2 s^2会趋于一致,该问题被称为模式崩溃。

上式中, z 1 z_1 z1 z 2 z_2 z2越接近,分子越小; s ^ 1 \hat{s}^1 s^1 s ^ 2 \hat{s}^2 s^2越远离,分母越大。因此该损失能够使 s ^ 1 \hat{s}^1 s^1 s ^ 2 \hat{s}^2 s^2 z 1 z_1 z1 z 2 z_2 z2接近的情况下仍能有明显区别。
4) L C L S \mathcal{L}_{CLS} LCLS
交叉熵分类损失,公式如下:

L C L S = ∑ k = 1 K p k log ⁡ ( p ^ k ) \mathcal{L}_{CLS}=\sum_{k=1}^{K}p_k\log(\hat{p}_k) LCLS=k=1Kpklog(p^k)

其中, p k p_k pk p ^ k \hat{p}_k p^k分别是在第 k k k个类上的真实标签和预测概率。

该损失在SoftMax分类器 M \mathcal{M} M之后,分类器在特征提取器 F \mathcal{F} F之后。经该损失计算的特征包括真实样本的特征 a i a_i ai和生成器生成的特征 s ^ 1 \hat{s}^1 s^1 s ^ 2 \hat{s}^2 s^2

该损失能够使不同类别的特征区分更加明显。

整合上述损失即可形成如下整体损失:

其中, α 1 \alpha_1 α1 α 2 \alpha_2 α2为权重超参数。

训练完成后,我们可以得到特征提取器 F \mathcal{F} F(仅微调,之后冻结)、生成器 G \mathcal{G} G、判别器 D \mathcal{D} D、分类器 M \mathcal{M} M、原型集 { s 1 , s 2 , ⋯   , s ∣ C b a s e ∣ } \{s_1,s_2,\cdots,s_{|C_{base}|}\} {s1,s2,,sCbase}(|C_{base}|是初始训练集总类别数)。

2. 在增量训练集上的学习过程

因为初始训练集已经模拟了增量训练集的少样本增量条件,因此大部分损失做简单调整后是可以直接迁移的,不过其中 L V 2 S \mathcal{L}_{V2S} LV2S不再使用,作者的意思是通过初始训练集的学习模型已经具备了在少样本条件下学习的能力。

2.1 蒸馏损失

为了避免灾难性遗忘,作者引入蒸馏损失 L D I S \mathcal{L}_{DIS} LDIS,它为上一阶段和最新的模型输入相同的旧原型判断输出的一致性,输出一致则说明未发生灾难性遗忘。

假设在初始训练集上的学习过程 t = 0 t=0 t=0,每一次 K K K-way N N N-shot的增量训练集都使 t = t + 1 t=t+1 t=t+1,便有如下蒸馏损失:

其中, k k k从1取到 ∣ C b a s e ∣ + K ( t − 1 ) |C_{base}|+K(t-1) Cbase+K(t1),表示仅需保证模型对旧原型的输出一致性。

上述蒸馏损失前一部分能够保证最新的生成器 G t \mathcal{G}_t Gt和上一阶段的生成器 G t − 1 \mathcal{G}_{t-1} Gt1在喂入相同的噪声和旧原型时输出能够尽可能一致;后一部分则是用来保证输出假特征的分布能够尽可能一致(作者使用的 Q \mathcal{Q} Q是多元高斯分布)。

2.2 增量扩展

假设当前阶段 K K K-way N N N-shot的增量训练集为 A ( t ) \mathcal{A}(t) A(t),则分类器 M \mathcal{M} M可以增加 K K K个神经元来纳入新类。结合之前的几个损失(不包括 L V 2 S \mathcal{L}_{V2S} LV2S),可以得到如下综合损失:

其中, α 1 \alpha_1 α1 α 2 \alpha_2 α2为权重超参数; s 0 : t − 1 \textbf{s}^{0:t-1} s0:t1 ∣ C b a s e ∣ + K ( t − 1 ) |C_{base}|+K(t-1) Cbase+K(t1)个旧原型。

学习开始前,新类的 K K K个原型会被计算并与旧原型合并。在学习新类时, M \mathcal{M} M中与旧类相关的网络权重被冻结。

总的来说,该工作重心在使GAN能够在少样本场景下学习上,引入语义映射模块是主要贡献,灾难性遗忘通过原型扩充、保留和蒸馏损失避免。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fulin_Gao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值