互信息(Mutual Information)
两个随机变量 XXX 和 YYY 的互信息定义为
I(X,Y)=∬p(x,y)logp(x,y)p(x)p(y)dxdy I(X,Y) = \iint p(x,y) \log \frac{p(x,y)}{p(x)p(y)}\mathrm{d}x\mathrm{d}y I(X,Y)=∬p(x,y)logp(x)p(y)p(x,y)dxdy
这是用来衡量两个随机变量相关性的一个量,显而易见若 XXX 和 YYY 独立分布,则 p(x,y)=p(x)p(y)p(x, y) = p(x)p(y)p(x,y)=p(x)p(y) ,它们的互信息是0。
假设深度神经网络的输入是 XXX , 输出是 YYY ,中间的某些层的特征表示我们把它当成是一种随机编码 YYY ,如此,这个网络结构可以看成是
X→Z→Y X \rightarrow Z \rightarrow Y X→Z→Y
我们希望一个好的特征表示 ZZZ 应该满足:
- ZZZ 尽可能是对 XXX 的压缩。
- ZZZ 对于预测 YYY 应该具有最大的信息量。
把这两个条件写成数学公式就是:
maxI(Z,Y;θ)s.t.minI(X,Z;θ) \max I(Z, Y;\theta)\qquad \mathrm{s.t.} \min I(X, Z;\theta) maxI(Z,Y;θ)s.t.minI(X,Z;θ)
其中 θ\thetaθ 是网络的参数。这个目标的约束条件不直观,引入信息约束 IcI_cIc ,可以改写成:
maxI(Z,Y;θ)s.t.,I(X,Z;θ)≤Ic \max I(Z, Y;\theta)\qquad \mathrm{s.t.} , I(X, Z;\theta) \le I_c maxI(Z,Y;θ)s.t.,I(X,Z;θ)≤Ic
引入拉格朗日乘子,目标转变为最大化:
RIB=I(Z,Y;θ)−βI(X,Z;θ) R_{IB} = I(Z, Y;\theta) - \beta I(X, Z;\theta) RIB=I(Z,Y;θ)−βI(X,Z;θ)
这个 RIBR_{IB}RIB 就是信息瓶颈[1],它的意义就是要学习到一种编码,能够对于输出的预测具有最大的表达能力,同时对于输入信息具备最大的压缩能力。
两个互信息分别可以展开成,
I(Z,Y)=∫p(y,z)logp(y,z)p(y)p(z)dydz=∫p(y,z)logp(y∣z)p(y)dydz \begin{aligned} I(Z, Y) &= \int p(y,z)\log \frac{p(y,z)}{p(y)p(z)}\mathrm{d}y\mathrm{d}z \\ &= \int p(y,z) \log \frac{p(y|z)}{p(y)}\mathrm{d}y\mathrm{d}z \end{aligned} I(Z,Y)=∫p(y,z)logp(y)p(z)p(y,z)dydz=∫p(y,z)logp(y)p(y∣z)dydz
I(X,Z)=∫p(x,z)logp(x,z)p(x)p(z)dxdz=∫p(x,z)logp(z∣x)p(z)dxdz \begin{aligned} I(X, Z) &= \int p(x,z)\log \frac{p(x,z)}{p(x)p(z)}\mathrm{d}x\mathrm{d}z \\ &= \int p(x,z) \log \frac{p(z|x)}{p(z)}\mathrm{d}x\mathrm{d}z \end{aligned} I(X,Z)=∫p(x,z)logp(x)p(z)p(x,z)dxdz=∫p(x,z)logp(z)p(z∣x)dxdz
变分信息瓶颈(Variational Information Bottleneck)
基于此再假设 XXX , YYY , ZZZ 之间的关系满足如下的马尔可夫链(即 ZZZ 不能直接由 YYY 决定):
Y↔X↔Z Y \leftrightarrow X \leftrightarrow Z Y↔X↔Z
联合分布可以分解成:
p(X,Y,Z)=p(X)p(Y∣X)p(Z∣X) p(X,Y,Z)=p(X)p(Y|X)p(Z|X) p(X,Y,Z)=p(X)p(Y∣X)p(Z∣X)
这个马尔可夫链上的分布可以完全由我们的编码器 p(Z∣X)p(Z|X)p(Z∣X) 和马尔可夫链本身的约束得到,利用马尔可夫链,还可以得到,
p(y∣z)=∫p(x,y∣z)dx=∫p(y∣x)p(x∣z)dx=∫p(y∣x)p(z∣x)p(x)p(z)dx p(y|z) = \int p(x, y|z) \mathrm{d}x = \int p(y|x)p(x|z) \mathrm{d}x = \int \frac{p(y|x)p(z|x)p(x)}{p(z)}\mathrm{d}x p(y∣z)=∫p(x,y∣z)dx=∫p(y∣x)p(x∣z)dx=∫p(z)p(y∣x)p(z∣x)p(x)dx
其中 p(z∣x)p(z|x)p(z∣x) 是编码器, 因此 p(y∣z)p(y|z)p(y∣z) 完全由我们的编码器和这个马尔可夫链本身所决定。
对于 I(Z,Y)I(Z, Y)I(Z,Y)
因为 p(y∣z)p(y|z)p(y∣z) 无法直接计算,假设 q(y∣z)q(y|z)q(y∣z) 是 p(y∣z)p(y|z)p(y∣z) 的变分近似(我们用模型来计算它,就是我们的解码器模块),利用KL散度非负的特性:
KL[p(y∣z),q(y∣z)]≥0⇒∫p(y∣z)logp(y∣z)dy≥∫p(y∣z)logq(y∣z)dy \mathrm{KL}\left[p(y|z), q(y|z)\right] \ge 0 \Rightarrow \int p(y|z) \log p(y|z) \mathrm{d}y \ge \int p(y|z) \log q(y|z) \mathrm{d}y KL[p(y∣z),q(y∣z)]≥0⇒∫p(y∣z)logp(y∣z)dy≥∫p(y∣z)logq(y∣z)dy
因此有,
I(Z,Y)≥∫p(y,z)logq(y∣z)p(y)dydz=∫p(y,z)logq(y∣z)dydz−∫p(y)logp(y)dy=∫p(y,z)logq(y∣z)dydz+H(Y) \begin{aligned} I(Z, Y) &\ge \int p(y,z) \log \frac{q(y|z)}{p(y)}\mathrm{d}y\mathrm{d}z \\ &= \int p(y,z) \log q(y|z)\mathrm{d}y\mathrm{d}z - \int p(y) \log p(y)\mathrm{d}y \\ &= \int p(y,z) \log q(y|z)\mathrm{d}y\mathrm{d}z + H(Y) \end{aligned} I(Z,Y)≥∫p(y,z)logp(y)q(y∣z)dydz=∫p(y,z)logq(y∣z)dydz−∫p(y)logp(y)dy=∫p(y,z)logq(y∣z)dydz+H(Y)
H(Y)H(Y)H(Y) 是 标签 yyy 的概率分布的熵,这个和我们的优化过程无关,可以被忽略掉,
I(Z,Y)≥∫p(y,z)logq(y∣z)dydz I(Z, Y) \ge \int p(y,z) \log q(y|z)\mathrm{d}y\mathrm{d}z I(Z,Y)≥∫p(y,z)logq(y∣z)dydz
将 p(y,z)p(y,z)p(y,z) 写成 p(y,z)=p(x)p(y∣x)p(z∣x)p(y,z) = p(x)p(y|x)p(z|x)p(y,z)=p(x)p(y∣x)p(z∣x) ,可以得到新的下界:
I(Z,Y)≥∫p(x)p(y∣x)p(z∣x)logq(y∣z)dxdydz I(Z,Y) \ge \int p(x)p(y|x)p(z|x)\log q(y|z) \mathrm{d}x\mathrm{d}y\mathrm{d}z I(Z,Y)≥∫p(x)p(y∣x)p(z∣x)logq(y∣z)dxdydz
对于 I(Z,X)I(Z, X)I(Z,X)
对于 XXX 和 ZZZ 之间的互信息,
I(Z,X)=∫p(x,z)logp(z∣x)p(z)dxdz=∫p(x,z)logp(z∣x)dxdz−∫p(z)logp(z)dz I(Z, X) = \int p(x,z) \log \frac{p(z|x)}{p(z)}\mathrm{d}x\mathrm{d}z = \int p(x,z) \log p(z|x) \mathrm{d}x\mathrm{d}z - \int p(z)\log p(z)\mathrm{d}z I(Z,X)=∫p(x,z)logp(z)p(z∣x)dxdz=∫p(x,z)logp(z∣x)dxdz−∫p(z)logp(z)dz
计算 ZZZ 的边际分布 p(z)=∫p(z∣x)p(x)dxp(z) = \int p(z|x)p(x)\mathrm{d}xp(z)=∫p(z∣x)p(x)dx 不是一件容易的事情,因此,让 r(z)r(z)r(z) 作为这个边际分布的一个变分近似,利用 KL[p(z),r(z)]≥0\mathrm{KL}\left[p(z), r(z)\right] \ge 0KL[p(z),r(z)]≥0 ,可以得到:
∫p(z)logp(z)dz≥∫p(z)logr(z)dz \int p(z) \log p(z) \mathrm{d}z \ge \int p(z) \log r(z) \mathrm{d}z ∫p(z)logp(z)dz≥∫p(z)logr(z)dz
因此,
I(Z,X)≤∫p(x)p(z∣x)logp(z∣x)r(z)dxdz \begin{aligned} I(Z, X) \le \int p(x)p(z|x)\log \frac{p(z|x)}{r(z)}\mathrm{d}x\mathrm{d}z \end{aligned} I(Z,X)≤∫p(x)p(z∣x)logr(z)p(z∣x)dxdz
Variational Information Bottleneck (RVIBR_{VIB}RVIB)
结合 I(Z,Y)I(Z, Y)I(Z,Y) 的下界和 I(Z,X)I(Z,X)I(Z,X) 的上界,可以得到,
RIB=I(Z,Y)−βI(Z,X)≥∫p(x)p(y∣x)p(z∣x)logq(y∣z)dxdydz−∫p(x)p(z∣x)logp(z∣x)r(z)dxdz=RVIB R_{IB} = I(Z,Y) - \beta I(Z, X) \ge \int p(x)p(y|x)p(z|x)\log q(y|z) \mathrm{d}x\mathrm{d}y\mathrm{d}z - \int p(x)p(z|x)\log \frac{p(z|x)}{r(z)}\mathrm{d}x\mathrm{d}z = R_{VIB} RIB=I(Z,Y)−βI(Z,X)≥∫p(x)p(y∣x)p(z∣x)logq(y∣z)dxdydz−∫p(x)p(z∣x)logr(z)p(z∣x)dxdz=RVIB
在实际计算中,将 p(x,y)p(x,y)p(x,y) 用经验分布 p(x,y)=1N∑i=1Nδxn(x)δyn(y)p(x,y) = \frac{1}{N}\sum_{i=1}^N\delta_{x_n}(x)\delta_{y_n}(y)p(x,y)=N1∑i=1Nδxn(x)δyn(y) 来替代,可以得到
RVIB=1N∑i=1N[∫p(z∣xn)logq(yn∣z)−βp(z∣xn)logp(z∣xn)r(z)dz] R_{VIB} = \frac{1}{N}\sum_{i=1}^N\left[\int p(z|x_n)\log q(y_n|z) - \beta p(z|x_n)\log \frac{p(z|x_n)}{r(z)}\mathrm{d}z\right] RVIB=N1i=1∑N[∫p(z∣xn)logq(yn∣z)−βp(z∣xn)logr(z)p(z∣xn)dz]
假设编码器是类似VAE的结构 p(z∣x)=N(z∣feμ(x),feΣ(x))p(z|x) = \mathcal{N}(z|f^{\mu}_e(x),f^{\Sigma}_e(x))p(z∣x)=N(z∣feμ(x),feΣ(x)) , fef_efe 是编码器网络,可以利用重参数化技巧,得到
p(z∣x)dz=p(ϵ)dϵ p(z|x)\mathrm{d}z = p(\epsilon)\mathrm{d}\epsilon p(z∣x)dz=p(ϵ)dϵ
其中 ϵ\epsilonϵ 是高斯随机变量。
假设 p(z∣x)p(z|x)p(z∣x) 和 r(z)r(z)r(z) 的选择使得我们能够计算KL散度,于是,最大化变分信息瓶颈在实际计算中可以转变为最小化如下目标:
JIB=1N∑n=1NEϵ∼p(ϵ)[−logq(yn∣f(xn,ϵ))]+βKL[p(z∣xn)∣r(z)] J_{IB} = \frac{1}{N}\sum_{n=1}^N\mathbb{E}_{\epsilon \sim p(\epsilon)}[-\log q(y_n|f(x_n,\epsilon))] + \beta \mathrm{KL}[p(z|x_n)|r(z)] JIB=N1n=1∑NEϵ∼p(ϵ)[−logq(yn∣f(xn,ϵ))]+βKL[p(z∣xn)∣r(z)]
参考文献
- Tishby, Naftali, Fernando C. Pereira, and William Bialek. “The information bottleneck method.” arXiv preprint physics/0004057 (2000).
- Maximilian Igl, Kamil Ciosek, Yingzhen Li, Sebastian Tschiatschek, Cheng Zhang, Sam Devlin and Katja Hofmann. “Generalization in Reinforcement Learning with Selective Noise Injection and Information Bottleneck.” arXiv preprint cs.LG/1910.12911 (2019).
- Alemi, Alexander A., et al. “Deep variational information bottleneck.” arXiv preprint arXiv:1612.00410 (2016).