本篇文章发表于NeurIPS 2022。
文章链接:https://arxiv.org/abs/2209.09056
一、概述
作者在Abstract中指出,当前的concept bottleneck models(CBMs)无法在high accuracy、robust concept-based explanations以及effective interventions on concepts之间达到最优的权衡,尤其对于现实世界中缺乏对concepts的supervision的情况,使我们无法很好地在test time进行有效的intervention(因为不知道如何干预)。
在Introduction部分作者进一步陈述了CBMs的特点和存在的问题。首先,由于CBMs考虑了concept accuracy,这有可能会以牺牲task accuracy为代价,使得最终的预测准确率降低,尤其是当提供的概念集并不包含准确预测下游任务所需的所有必要信息时(概念集是必要不充分的,实际上也很难做到充分。比如我们不能只根据颜色是黄色来分辨香蕉和玉米,但是又很难提供关于香蕉或者玉米的完整信息)。
本文提出了concept embedding models即CEM,一种新的CBM,通过学习可解释概念的高维representation,打破了准确性-可解释性之间的compromise,在intervention时的表现超越了CBMs,并且适用于缺乏概念监督的现实世界情况。
实验表明,CEM有以下特点:
(1) 与没有concepts的标准神经网络相比精度更高;
(2) 能够捕获meaningful semantics,包括并不限于ground-truth concept labels;
(3) 支持test-time intervention,对测试精度的影响超过了CBMs;
(4) 适用于缺乏完整概念监督的real-world conditions。
二、方法
与CBMs使用简单的二值concept不同,CEM使用的是一个高维的embedding来表示每个concept,这样allows for extra supervised learning capacity. 这也是处理概念集不充分情况的关键。
(i) Architecture
对于每个concept,CEM同时学习两个embeddings(with explicit semantics representing the concept's activity,即,具有表示概念“活跃情况”的显式语义);简而言之,就是CEM可以学习两个embedding:一个支持概念存在,一个反对概念存在;在test-time可以通过两种 switch between the two embedding states来实现intervention。
两个embeddings表示为 ,其中
代表concept的active state (concept is 'true',支持概念存在)、
代表concept的inactive state (concept is 'false',反对概念存在);基于此,首先使用一个DNN
学习latent representation
,接着
将作为embedding generator的输入用以产生对应的两个embeddings in
,即:
可以发现此处的generator使用的是简单的one-layer neural network,目的是防止当我们拥有规模比较大的概念时模型的参数过多。(当然可以使用更复杂的网络以生成embeddings)
设计一个learnable and differentiable的scoring function 以鼓励生成的两个embeddings
与
与ground-truth concept
相匹配(aligned)。
具体做法是,定义probability of concept being active from the embeddings' joint space:
For parameter efficiency, and
are shared across all concepts.
The final concept embedding:
直观地说,这样的设计有两个目的:(1) 当第 个concept
存在时(active,
),此时受损失函数(CrossEntropy)限制,
也接近于1,那么得到的
倾向于只依赖
;反之如果第
个concept
不存在时(inactive,
),此时
也接近于0,那么得到的
倾向于只依赖
。这样一来就促使模型生成了两个具有不同语义的latent space;并且,也提供了intervention时的方式,即,将两个embedding进行switch。
之后,将所有concept( 个类别)对应的concept embeddings组合起来(concatenate),得到
with
units,作为bottleneck输入到最后的预测网络
以获得下游任务的标签
。此处
选择的是a simple linear layer,当然可以使用更复杂的function。类似于CBM,CEM也可以提供concept-based explanation for the output of
through the concept probability vector
,代表了每一个concept的“活跃度”。
整个模型可以通过SGD进行端到端的训练,损失函数同时考虑了task accuracy与concept accuracy,用超参数 进行权衡:
(ii) Interventing with Concept Embeddings
与CBM相同,CEM也支持test-time intervention。具体做法是在测试时对产生的 进行干预,使其变为ground-truth即
;而
是由
计算得到的,因此实际操作中是改变
的值。举个例子,比如CEM预测的
,但人类专家知道此时
,因此采用intervention将
强制置为1即可。所谓的switch就体现在此,原本是:
而现在变成了:
此外,作者引入了一种规范化策略RandInt,使CEM可以在训练时就接受干预,以提高测试时采用干预的有效性。(通过RandInt提高了test-time intervention的有效性,一个比较新颖的思路,先让模型“适应”一下intervention这个操作)
具体来说, 会以概率
被置为 ground-truth
,而以
的概率保持预测值
这个策略强制来自下游任务的feedback只用来更新correct concept embedding,而来自concept的预测对两个embedding同时进行更新。作者将这种做法类比为dropout,因为一部分concept representation仅使用来自其概念标签的feedback进行训练,而另一些concept representation同时接受来自其概念和任务标签的feedback。
(iii) Evaluating concept bottleneck
提出Concept Alignment Score (CAS)用以评估学习到的concept roncepts与clusters做了归一化,将CAS归一化到0到1之间。
具体来说,对于每一个concept,CAS会使用一种聚类算法寻找 个簇,并根据每个样本
生成的第
个concept representations
进行聚类,为
分配一个簇标签
,然后对样本ground-truth label与聚类分配的簇标签计算条件熵,如果条件熵越小,并由条件熵定义一个homogeneity score,如果该分数越高说明学到的concept representation与ground-truth concept匹配得越好。
简单理解,当具有同一个ground-truth concept的样本的concept representation在对应的latent space中的分布也相近,并且与其它类别分的比较开的时候,就代表这个representation学的比较好。
随后,作者基于信息论研究了information bottleneck,推测基于embedding的CBM通过充分保留来自于输入分布的信息来避免发生information bottleneck;而scalar-based concept representations会被迫在concept层面压缩来自输入的信息,从而造成了信息瓶颈,导致了输入和概念层、概念层和输出之间的compromise。
简单来说,就是用一个值表示concept(scalar-based representations)会造成信息的丢失(因为一个值能提供的信息容量有限),而高维的embedding可以保留来自输入的信息,避免信息瓶颈的发生。更多细节请查阅原论文。
三、实验及结果
- 蓝色点代表CEM,task accuracy很高;
- CAS分数也很高,证明学习到的representation能够与concept对齐。
- 在训练过程中,蓝色(CEM)不会存在information bottleneck的问题——即,随着训练的迭代,模型不断压缩来自输入数据的信息,造成遗忘。
- 与使用scalar concept representations的CBMs相比,基于embedding的表示能够保持高的task accuracy和mean concept accuracy。额外的容量(高维度的embedding)允许CBMs在不过度约束概念表示的情况下最大化概念准确性,从而允许有用的输入信息通过。“In CEMs all input information flows through concepts, as they supervise the whole concept embedding.”
- CEM的embedding space形成了两个明显的簇。
- 右边模型(Hybrid)的latent space无法像CEM那样清晰地捕获概念语义。
- 可以通过查询the nearest Euclidean neighbours来观察相近的图片(图c)。
- 1,2模拟正确干预;3,4模拟错误干预(模拟人类的失误,检验模型的容错性);
- CEM对intervention高度响应,并且RandInt可以使响应更显著(深蓝色,斜率更大);
- 与bool-CBM和fuzzy-CBM相比,CEM对人类的错误干预更有鲁棒性(斜率更小);
- hybrid-CBM(粉色)几乎不受干预影响(斜率趋近于0),即使它对人类的错误intervention也没什么响应,但是这是因为它对什么样的干预都没反应,而不是有鲁棒性。