counterfactual generation network——因果推断用于提高分类器鲁棒性

本文探讨了如何利用因果推断原理改进图像分类任务。针对传统分类器易受图像背景和纹理干扰的问题,提出一种方法将图像分解为形状、纹理和背景三个独立组件,并借助BigGAN生成反事实图像。通过训练一个多头分类器来增强模型的鲁棒性,使其能够关注于真正的分类依据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

因果推断本是一个小众方向,但是感觉近几年突然火了起来,本文便是2021年发表于ICLR的一篇将因果推断用于提高图像分类的文章。

本文出发点

先说一下本文针对的问题:我们知道图像分类任务就是将输入网络的图像尽量映射成我们想要的输出结果(通常是一个one-hot编码)。网络最终输出的label,是图像的多个特征共同作用的结果,因此对于通常的分类器,我们并不能保证其分类标准是按照我们人类想的那样进行的。例如说我们想对一头奶牛进行分类,我们的训练图像多是黑白相间的奶牛站在草地上,虽然我们人去判断图像是否是一只奶牛的时候是按照这个目标的外貌轮廓来判断的,但是比起这个,去判断这个图像的纹理是否是黑白相间/背景是否是草地显然比去提取目标轮廓要轻松的多。虽然这样训练出来的网络在面对和训练集同分布的测试集时是没有太大问题的,但是这并不是我们想要的结果:我们希望在经过训练后,即使给网络一个站在树上的,紫色的奶牛,网络依然可以判断出它是个奶牛(或者说至少可以判断出它是个牛),但是一旦网络在训练时抄了近道(shortcut learning)将主要的判断依据归为了目标纹理和目标背景,而忽略了真正重要但是难以学习的目标轮廓,那么网络显然难以判断上述图像是否是个牛。因此如何避免网络将目标和一些无关信息建立虚假的相关性,就是本文想要解决的问题。
我们知道因果推断的一个核心观点就是区分虚假的相关性和真实的因果关系,因此,作者从因果的角度出发,解决上述问题:在因果推断里面有一个重要的概念叫做可交换性(exchangeability)它是说当一个操作(treatment)仅仅对目标有影响,而不受其他影响目标的操作的影响时,我们就可以将这个操作和目标的概率关系等价于二者的因果关系,也就是这个目标的产生是由这个操作导致的。因此要想使得分类和我们想让分类器注意的目标产生直接的因果关系,我们就要将其和其他我们认为无关紧要的因素解纠缠,或者我们用文中的话来说,让每个因素成为一个独立的组件(independent mechanisms)。

具体实施过程

作者通过上面所提到的分类结果被图像背景/纹理所干扰的事实出发,将图像分解为三个独立的组件:object’s shape, the object’s texture and the object’s background。然后通过特殊的生成器去分别生成这三个组件,最后再将其拼装起来,就可以获得一张新的图像。由于这三个组件各自是独立的,那么我们在生成这三个组件后,可以将其任意组合,显然我们的组合结果会千奇百怪,因此作者个这些组合结果一个名字——反事实图像(counterfactual images)。之后作者提出用这些生成图像去训练一个三头分类器,每个头去判断其中的一个因素,将会极大的提高分类的鲁棒性。在这里插入图片描述
可以看到对于普通的鸵鸟图像,多头分类器的三个输出都是鸵鸟,那么我们无法去肯定我这三个头是根据图像的三个组件来判断出的结果,然而对于反事实鸵鸟图像,多头分类器应该输出shape:鸵鸟,texture:草莓,backgorund:水里,那么如果我们人觉得判断一个东西是不是鸵鸟的依据应该是它是否有鸵鸟的轮廓,则我们应该取shape的输出作为整个多头分类器的最终输出。因此,这篇文章的一个贡献就是将独立图像组件生成与鲁棒分类联系在了一起,使得分类器可达到域外鲁棒(对于训练集里没有出现的数据也有着很好的鲁棒性)。
多头分类器无须进一步介绍,那么接下来我们重点来看一下反事实图像是如何生成的
在这里插入图片描述
前面提到的三个组件,实际上我们可以将其看做是三张独立的特殊图像,而我们知道GAN有着强大的图像生成能力,对于GAN来说,要想控制生成结果,我们需要引入标签,也就是条件GAN,因此在反事实图像生成过程中,作者引入了目前最强大的条件GAN——BIgGAN来作为基本模型,并且这里要说一下,作者在ImageNet进行训练的时候,引入的BigGAN都是在ImageNet上已经预训练好的,这使得本模型的训练得到了极大的简化:首先作者通过在正态分布里随机采样一个 u u u,并且随机采样一个label y y y作为网络的初始输入,之后分为四个流程,注意,在流程图中所有的绿色网络表示在整个训练过程中参数固定不会被训练,蓝色网络表示参数可以被训练,其中的BIgGAN都是已经在ImageNet上预训练好的
一、最上面的通过BigGAN生成一张图像,此时这张图像会包含近乎完整的三个部分(shape,texture,background)信息,之后作者引入了一个预训练好的U2-Net网络,用来提取生成图像的object shape,这里简介一下U2-Net网络,这是一个类不可知的网络,其输入是一张任意图像,输出是它认为的图像中重要目标的轮廓,其中目标部分用1表示,背景部分用0表示。在这个流程中,我们需要优化两个损失:其一是 L b i n a r y L_{binary} Lbinary,通过使得U2-Net的输出(在这个流程中我们可以称之为mask)接近于1或者是0,来逐渐突出BigGAN生成图像的object。其二是 L m a s k L_{mask} Lmask,通过避免生成的mask全为1或者全为0,来避免模型的崩溃。在这个过程中BigGAN生成的图像object会被逐渐突出,而背景和texture则会被逐渐边缘化,最终U2-Net可以通过BIgGAN的生成,来提取一个完美的mask图像。在这里插入图片描述
二、第二个流程是为了生成object的texture。这里的生成方式很简单,就是直接用BIgGAN来生成,重点说一下这个流程的损失:这里作者为了使得生成的图像成为一张object的texture贴图,使用了一种很特殊的优化方式,通过上一个流程,我们可以获知哪些地方是object的内部和外部,则我们会选择这些内部部分的patch,去粘贴到外部部分,使得其patch粘贴后和粘贴前的感知损失最小化,这就可以迫使在这个流程中,BIgGAN的生成图像无论object内部还是object外部都充满着object的纹理。具体的, L t e x t ( f , pg ) L_{text} (f,\textbf{pg}) Ltext(f,pg)其中 f f f是BIgGAN直接输出的图像, pg \textbf{pg} pg是patch粘贴后的图像。在这里插入图片描述
三、第三个流程是为了生成一张object的background。这里和第一个流程类似,也是先用BIgGAN生成一张图像,然后输入到一个预训练的U2-Net中,用输出的object提取图来进行损失计算。但是与第一个流程不同的是,这里是逆用U2-Net,也就是要想U2-Net的object提取最小化。其实这个很好理解,为了提取图像的背景信息,我们需要尽可能的去掉object信息,而评价object信息是否被很好的去除(或者说被很好的用背景信息inpainted)我们可以用U2-Net的反映程度来测量。通过不断优化这个目标(文中用 L b g L_{bg} Lbg表示),最终BIgGAN的输出图像中object不断收缩,最终输出一张只包含背景的图像。在这里插入图片描述
四、就是一个训练好的,而且被固定参数的BIgGAN,只不过这个网络的输入高纬特征和输入的label和前面三个流程是一样的,因此这个网络输出的应该是前三个流程的输出结果的一个组合。这个流程不参与优化,只不过其输出会被用于后面总体的合成图像的重建误差计算。

补:这里还需要说一点,就是我们如何将前面三个流程的输出结果组合成一张完整的图像,也就是流程图中的 C C C C : = x g e n = C ( m , f , b ) = m ⊙ f + ( 1 − m ) ⊙ b C:= x_{gen} = C(m,f,b) = m \odot f + (1 − m) \odot b C:=xgen=C(m,f,b)=mf+(1m)b这里可以看出,C不包含任何参数,仅仅是三个图像的一种组合方法。
以上是整个反事实图像生成模型的训练方法,这里值得注意的是,在这个阶段我们输入到四个流程中的 u u u y y y在同一次训练时是相同的。而在应用阶段,由于我们仅仅需要 x g e n x_{gen} xgen,所以并没有这个限制。
通过以上方法,对于ImageNet(1000类),我们理论上可以生成 100 0 3 1000^3 10003种图像,并且如果我们默认分类标准与其中一个或两个组件无关时,我们可以通过对这些无关组建随机采样生成,来获得理论上无限的反事实图像。

实验分析

本文所做的数值实验一共有四个:
一、探索不同loss对实验结果的影响。通过上面的分析,我们可以看出反事实图像生成模型需要好多个损失共同优化,并且损失之间有着依赖关系,因此作者在这里首先通过消融实验来探索了各个损失对模型的作用在这里插入图片描述
其中IS值反映生成图像的多样性和真实性,在保证生成图像满足要求的前提下,越高越好, μ m a s k \mu_{mask} μmask代表的整张mask图的平均值,从直观来讲,显然这个值过大过小都表示了mask生成流程的崩溃,当生成的反事实图像几乎都是texture时,这个值会接近于0,当生成的反事实图像几乎都是background时,这个值则会接近于1。从上表可以看出,当 L s h a p e 或 L r e c L_{shape}或L_{rec} LshapeLrec没有得到优化时,模型的生成会变得无意义。而当 L t e x t L_{text} Ltext L b g L{bg} Lbg没有得到优化时,模型的这两个组件则会出现崩溃。
二、对比不同的伪图像生成方法的结果的质量。在这里插入图片描述
这里IRM和LNTL是两种不同的伪图像生成方法。可以看到,对于三种不同的mnist数据,CGN(即本文的方法)生成的伪图像训练出的分类器,其测试精度都是最高的,并且是提高幅度非常大。尽管它们在训练精度上略微有些下降。
三、对训练数据集中改变不同的反事实图像组件,来改变数据集的分布,对结果的影响。在这里插入图片描述
其中SIN意思是程式化的ImageNet数据集,也就是一张图像有多个标签。从表中可以看到,当改变不同的组件时,多头分类器每个头部的分类偏差会受到显著影响,但是其整体的测试精度却几乎不会受到影响。
四、在这里插入图片描述
第四个实验室在ImageNet-9数据集上进行的(一共9类数据)。其中BG-Gap表示分类与背景信息的相关性,Mixed-Same表示不同的class object,相同的background,mixed-rand表示相同的object不同的background。可以看到虽然mixed-rand使得分类器几乎不会对背景信息产生应答(对应表格中最后一列),然而在分类测试精度上依然要差于本文方法,这就很神奇。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值