变分信息瓶颈(VIB)

原理推导

1.互信息

2.信息瓶颈

也就是我们假设从输出到特征空间再到输出是一个马尔科夫过程

现在我们的目标就是让Z充分包含Y的信息,同时为了压缩效果是X和Z的互信息尽可能少

也就是:

引入拉格朗日乘子,即最大化目标:

将公式进一步展开:

下面介绍变分信息瓶颈:

现在我们对X,Z,Y的马尔科夫链进一步假设:

这其实是符合神经网络模型的:

因此,我们有:

我们可以进一步得到:

因为 p(y|z) 无法直接计算,假设 q(y|z) 是 p(y|z) 的变分近似,也就是我们模型中的encoder

利用KL散度非负的特性,可得:

我们对I(Z,Y)化简:

将 p(y,z) 写成 p(y,z)=p(x)p(y|x)p(z|x) ,可以得到新的下界:

现在我们考虑化简I(Z,X)

同理,让r(z)作为p(z)的变分近似:

因此,可得:

现在我们代入上述不等式对信息瓶颈的公式进行化简:

在实际计算中,由于我们只能对样本采样来拟合分布,因此:

Xn,Yn也就是数据集中的输入和输出的样本

进一步代入公式:

我们将输入X,通过encoder得到z的概率分布,然后进行采样,即:

考虑到“采样”这个操作是随机的,我们无法直接对它求导,于是进行参数重整化

最终,我们最小化:

其中:

下面展示一下怎么从代码层面进行实现(主要讲一下核心的函数):

KL散度函数

# 定义函数,接收两个分布的参数
def KL_between_normals(q_distr, p_distr):
    # 1. 解包输入参数
    # mu_q, sigma_q 的形状通常是 [batch_size, k],k是潜变量维度
    # sigma_q 是标准差 (standard deviation),而不是方差 (variance)
    mu_q, sigma_q = q_distr
    mu_p, sigma_p = p_distr
    
    # 2. 获取潜变量的维度 k
    k = mu_q.size(1)
    
    # 3. 计算均值差的平方项:(μ_p - μ_q)²
    # 这一步是逐元素计算的
    mu_diff = mu_p - mu_q
    mu_diff_sq =  torch.mul(mu_diff, mu_diff) # 等价于 mu_diff ** 2
    
    # 4. 计算两个分布的对数行列式 (log-determinant)
    # 对于对角协方差矩阵 Σ = diag(σ₁², σ₂², ...),其行列式 det(Σ) = Π σᵢ²
    # 对数行列式 log(det(Σ)) = log(Π σᵢ²) = Σ log(σᵢ²) = Σ 2*log(σᵢ)
    # torch.clamp 是为了数值稳定性,防止 sigma 变为 0 导致 log(0) = -inf
    logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1)
    logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1)
    
    # 5. 计算公式中的两个主要分数项
    # fs 对应于公式中的 Σ [ (σ_q²/σ_p²) + ((μ_p - μ_q)²/σ_p²) ]
    # torch.div 是逐元素除法
    # torch.sum(..., dim=1) 是沿着维度k进行求和,得到每个样本的总和
    # (sigma_q ** 2) 是方差 σ²
    term1 = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1)  # 对应 tr(Σ_p⁻¹ * Σ_q)
    term2 = torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1) # 对应 (μ_p - μ_q)ᵀ Σ_p⁻¹ (μ_p - μ_q)
    fs = term1 + term2
    
    # 6. 组合所有部分
    # 这里计算的是 2 * KL(q || p)
    # 2*KL = Σ [ σ_q²/σ_p² + (μ_p-μ_q)²/σ_p² - 1 - log(σ_q²) + log(σ_p²) ]
    #       = (fs) - Σ(1) + Σ(log(σ_p²)) - Σ(log(σ_q²))
    #       = fs - k + logdet_sigma_p - logdet_sigma_q
    two_kl =  fs - k + logdet_sigma_p - logdet_sigma_q
    
    # 7. 返回最终的 KL 散度值
    # 将结果除以2,得到最终的KL散度。
    # 输出的张量形状为 [batch_size],每个元素是对应样本的KL散度值
    return two_kl * 0.5

解释:

KL数学解释:

VIB模型:

class VIB(nn.Module):
    def __init__(self, X_dim, y_dim, dimZ=256, beta=1e-3, num_samples=10):
        super().__init__()
        
        # 1. 保存超参数
        self.beta = beta               # VIB损失函数中的权重 β
        self.dimZ = dimZ             # 潜变量 Z 的维度
        self.num_samples = num_samples # 蒙特卡洛采样次数,用于估算期望
        
        # 2. 定义编码器
        # 这是一个简单的多层感知机 (MLP)
        self.encoder = nn.Sequential(nn.Linear(in_features=X_dim, out_features=1024),
                                     nn.ReLU(),
                                     nn.Linear(in_features=1024, out_features=1024),
                                     nn.ReLU(),
                                     # 关键: 输出维度是 2 * dimZ
                                     # 一半用于均值(mu),一半用于标准差(sigma)
                                     nn.Linear(in_features=1024, out_features=2 * self.dimZ))
        
        # 3. 定义解码器
        # 论文中使用了简单的逻辑回归,这里就是一个线性层
        # 它接收潜变量 Z (dimZ),输出预测 Y 的 logits (y_dim)
        self.decoder_logits = nn.Linear(in_features=self.dimZ, out_features=y_dim)


    def gaussian_noise(self, num_samples, K):
        # works with integers as well as tuples   
        return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).cuda()
           
    def sample_prior_Z(self, num_samples):
        return self.gaussian_noise(num_samples=num_samples, K=self.dimZ)

    def encoder_result(self, batch):
        encoder_output = self.encoder(batch)
        
        mu = encoder_output[:, :self.dimZ]
        sigma = torch.nn.functional.softplus(encoder_output[:, self.dimZ:])
        
        return mu, sigma
    
    def sample_encoder_Z(self, num_samples, batch): 
        batch_size = batch.size()[0]
        mu, sigma = self.encoder_result(batch)
        
        return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
    
    
    def forward(self, batch_x):
        
        batch_size = batch_x.size()[0]
        
        # sample from encoder
        encoder_Z_distr = self.encoder_result(batch_x)  
        to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)

        decoder_logits_mean = torch.mean(self.decoder_logits(to_decoder), dim=0)
                
        return decoder_logits_mean
        
    def batch_loss(self, num_samples, batch_x, batch_y):
        batch_size = batch_x.size()[0]
        
        # --- 第一部分:计算 KL 散度项 I(Z;X) ---
        # 1. 定义先验分布 N(0, I)
        prior_Z_distr = torch.zeros(batch_size, self.dimZ).cuda(), torch.ones(batch_size, self.dimZ).cuda()
        # 2. 获取编码器产生的后验分布 p(z|x)
        encoder_Z_distr = self.encoder_result(batch_x)
        
        # 3. 计算 KL(p(z|x) || p(z)) 并对批次取平均
        I_ZX_bound = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
        
        # --- 第二部分:计算交叉熵项 -I(Z;Y) ---
        # 1. 使用重参数化技巧从 p(z|x) 中采样Z
        # to_decoder 形状: (num_samples, batch_size, dimZ)
        to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)
        
        # 2. 将Z送入解码器得到预测
        # decoder_logits 形状: (num_samples, batch_size, y_dim)
        decoder_logits = self.decoder_logits(to_decoder)
        
        # 3. 准备计算交叉熵损失
        # nn.CrossEntropyLoss期望的输入形状是 (N, C, ...),其中N是批大小,C是类别数
        # 当前形状是 (num_samples, batch_size, y_dim)
        # 我们把它变成 (batch_size, y_dim, num_samples) 来匹配PyTorch的要求
        decoder_logits = decoder_logits.permute(1, 2, 0)
    
        # 4. 计算每个样本、每个采样的交叉熵损失
        # reduce=False: 不对损失求和或求平均,返回每个元素的损失
        loss = nn.CrossEntropyLoss(reduce=False)
        # batch_y 形状是 (batch_size),需要扩展成 (batch_size, num_samples)
        # 以便和 (batch_size, y_dim, num_samples) 的预测logits对应
        cross_entropy_loss = loss(decoder_logits, batch_y[:, None].expand(-1, num_samples))
        
        # 5. 通过对采样维度求平均来估算期望 E[log q(y|z)]
        cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
        
        # 6. 对批次求平均,得到最终的交叉熵损失项
        minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0)
                
        # --- 合并两部分,得到最终的总损失 ---
        # J_IB = -I(Z;Y) + β * I(Z;X)
        total_loss = minusI_ZY_bound + self.beta * I_ZX_bound
        
        # 返回总损失和两个分量,方便监控训练过程
        return total_loss, -minusI_ZY_bound, I_ZX_bound

参考代码链接:https://github.com/makezur/VIB_pytorch/blob/master/VIB.ipynb

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值