原理推导
1.互信息
2.信息瓶颈
也就是我们假设从输出到特征空间再到输出是一个马尔科夫过程
现在我们的目标就是让Z充分包含Y的信息,同时为了压缩效果是X和Z的互信息尽可能少
也就是:
引入拉格朗日乘子,即最大化目标:
将公式进一步展开:
下面介绍变分信息瓶颈:
现在我们对X,Z,Y的马尔科夫链进一步假设:
这其实是符合神经网络模型的:
因此,我们有:
我们可以进一步得到:
因为 p(y|z) 无法直接计算,假设 q(y|z) 是 p(y|z) 的变分近似,也就是我们模型中的encoder
利用KL散度非负的特性,可得:
我们对I(Z,Y)化简:
将 p(y,z) 写成 p(y,z)=p(x)p(y|x)p(z|x) ,可以得到新的下界:
现在我们考虑化简I(Z,X)
同理,让r(z)作为p(z)的变分近似:
因此,可得:
现在我们代入上述不等式对信息瓶颈的公式进行化简:
在实际计算中,由于我们只能对样本采样来拟合分布,因此:
Xn,Yn也就是数据集中的输入和输出的样本
进一步代入公式:
我们将输入X,通过encoder得到z的概率分布,然后进行采样,即:
考虑到“采样”这个操作是随机的,我们无法直接对它求导,于是进行参数重整化
最终,我们最小化:
其中:
下面展示一下怎么从代码层面进行实现(主要讲一下核心的函数):
KL散度函数
# 定义函数,接收两个分布的参数
def KL_between_normals(q_distr, p_distr):
# 1. 解包输入参数
# mu_q, sigma_q 的形状通常是 [batch_size, k],k是潜变量维度
# sigma_q 是标准差 (standard deviation),而不是方差 (variance)
mu_q, sigma_q = q_distr
mu_p, sigma_p = p_distr
# 2. 获取潜变量的维度 k
k = mu_q.size(1)
# 3. 计算均值差的平方项:(μ_p - μ_q)²
# 这一步是逐元素计算的
mu_diff = mu_p - mu_q
mu_diff_sq = torch.mul(mu_diff, mu_diff) # 等价于 mu_diff ** 2
# 4. 计算两个分布的对数行列式 (log-determinant)
# 对于对角协方差矩阵 Σ = diag(σ₁², σ₂², ...),其行列式 det(Σ) = Π σᵢ²
# 对数行列式 log(det(Σ)) = log(Π σᵢ²) = Σ log(σᵢ²) = Σ 2*log(σᵢ)
# torch.clamp 是为了数值稳定性,防止 sigma 变为 0 导致 log(0) = -inf
logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1)
logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1)
# 5. 计算公式中的两个主要分数项
# fs 对应于公式中的 Σ [ (σ_q²/σ_p²) + ((μ_p - μ_q)²/σ_p²) ]
# torch.div 是逐元素除法
# torch.sum(..., dim=1) 是沿着维度k进行求和,得到每个样本的总和
# (sigma_q ** 2) 是方差 σ²
term1 = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1) # 对应 tr(Σ_p⁻¹ * Σ_q)
term2 = torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1) # 对应 (μ_p - μ_q)ᵀ Σ_p⁻¹ (μ_p - μ_q)
fs = term1 + term2
# 6. 组合所有部分
# 这里计算的是 2 * KL(q || p)
# 2*KL = Σ [ σ_q²/σ_p² + (μ_p-μ_q)²/σ_p² - 1 - log(σ_q²) + log(σ_p²) ]
# = (fs) - Σ(1) + Σ(log(σ_p²)) - Σ(log(σ_q²))
# = fs - k + logdet_sigma_p - logdet_sigma_q
two_kl = fs - k + logdet_sigma_p - logdet_sigma_q
# 7. 返回最终的 KL 散度值
# 将结果除以2,得到最终的KL散度。
# 输出的张量形状为 [batch_size],每个元素是对应样本的KL散度值
return two_kl * 0.5
解释:
KL数学解释:
VIB模型:
class VIB(nn.Module):
def __init__(self, X_dim, y_dim, dimZ=256, beta=1e-3, num_samples=10):
super().__init__()
# 1. 保存超参数
self.beta = beta # VIB损失函数中的权重 β
self.dimZ = dimZ # 潜变量 Z 的维度
self.num_samples = num_samples # 蒙特卡洛采样次数,用于估算期望
# 2. 定义编码器
# 这是一个简单的多层感知机 (MLP)
self.encoder = nn.Sequential(nn.Linear(in_features=X_dim, out_features=1024),
nn.ReLU(),
nn.Linear(in_features=1024, out_features=1024),
nn.ReLU(),
# 关键: 输出维度是 2 * dimZ
# 一半用于均值(mu),一半用于标准差(sigma)
nn.Linear(in_features=1024, out_features=2 * self.dimZ))
# 3. 定义解码器
# 论文中使用了简单的逻辑回归,这里就是一个线性层
# 它接收潜变量 Z (dimZ),输出预测 Y 的 logits (y_dim)
self.decoder_logits = nn.Linear(in_features=self.dimZ, out_features=y_dim)
def gaussian_noise(self, num_samples, K):
# works with integers as well as tuples
return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).cuda()
def sample_prior_Z(self, num_samples):
return self.gaussian_noise(num_samples=num_samples, K=self.dimZ)
def encoder_result(self, batch):
encoder_output = self.encoder(batch)
mu = encoder_output[:, :self.dimZ]
sigma = torch.nn.functional.softplus(encoder_output[:, self.dimZ:])
return mu, sigma
def sample_encoder_Z(self, num_samples, batch):
batch_size = batch.size()[0]
mu, sigma = self.encoder_result(batch)
return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
def forward(self, batch_x):
batch_size = batch_x.size()[0]
# sample from encoder
encoder_Z_distr = self.encoder_result(batch_x)
to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)
decoder_logits_mean = torch.mean(self.decoder_logits(to_decoder), dim=0)
return decoder_logits_mean
def batch_loss(self, num_samples, batch_x, batch_y):
batch_size = batch_x.size()[0]
# --- 第一部分:计算 KL 散度项 I(Z;X) ---
# 1. 定义先验分布 N(0, I)
prior_Z_distr = torch.zeros(batch_size, self.dimZ).cuda(), torch.ones(batch_size, self.dimZ).cuda()
# 2. 获取编码器产生的后验分布 p(z|x)
encoder_Z_distr = self.encoder_result(batch_x)
# 3. 计算 KL(p(z|x) || p(z)) 并对批次取平均
I_ZX_bound = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
# --- 第二部分:计算交叉熵项 -I(Z;Y) ---
# 1. 使用重参数化技巧从 p(z|x) 中采样Z
# to_decoder 形状: (num_samples, batch_size, dimZ)
to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)
# 2. 将Z送入解码器得到预测
# decoder_logits 形状: (num_samples, batch_size, y_dim)
decoder_logits = self.decoder_logits(to_decoder)
# 3. 准备计算交叉熵损失
# nn.CrossEntropyLoss期望的输入形状是 (N, C, ...),其中N是批大小,C是类别数
# 当前形状是 (num_samples, batch_size, y_dim)
# 我们把它变成 (batch_size, y_dim, num_samples) 来匹配PyTorch的要求
decoder_logits = decoder_logits.permute(1, 2, 0)
# 4. 计算每个样本、每个采样的交叉熵损失
# reduce=False: 不对损失求和或求平均,返回每个元素的损失
loss = nn.CrossEntropyLoss(reduce=False)
# batch_y 形状是 (batch_size),需要扩展成 (batch_size, num_samples)
# 以便和 (batch_size, y_dim, num_samples) 的预测logits对应
cross_entropy_loss = loss(decoder_logits, batch_y[:, None].expand(-1, num_samples))
# 5. 通过对采样维度求平均来估算期望 E[log q(y|z)]
cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
# 6. 对批次求平均,得到最终的交叉熵损失项
minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0)
# --- 合并两部分,得到最终的总损失 ---
# J_IB = -I(Z;Y) + β * I(Z;X)
total_loss = minusI_ZY_bound + self.beta * I_ZX_bound
# 返回总损失和两个分量,方便监控训练过程
return total_loss, -minusI_ZY_bound, I_ZX_bound
参考代码链接:https://github.com/makezur/VIB_pytorch/blob/master/VIB.ipynb