变分信息瓶颈(VIB)

原理推导

1.互信息

2.信息瓶颈

也就是我们假设从输出到特征空间再到输出是一个马尔科夫过程

现在我们的目标就是让Z充分包含Y的信息,同时为了压缩效果是X和Z的互信息尽可能少

也就是:

引入拉格朗日乘子,即最大化目标:

将公式进一步展开:

下面介绍变分信息瓶颈:

现在我们对X,Z,Y的马尔科夫链进一步假设:

这其实是符合神经网络模型的:

因此,我们有:

我们可以进一步得到:

因为 p(y|z) 无法直接计算,假设 q(y|z) 是 p(y|z) 的变分近似,也就是我们模型中的encoder

利用KL散度非负的特性,可得:

我们对I(Z,Y)化简:

将 p(y,z) 写成 p(y,z)=p(x)p(y|x)p(z|x) ,可以得到新的下界:

现在我们考虑化简I(Z,X)

同理,让r(z)作为p(z)的变分近似:

因此,可得:

现在我们代入上述不等式对信息瓶颈的公式进行化简:

在实际计算中,由于我们只能对样本采样来拟合分布,因此:

Xn,Yn也就是数据集中的输入和输出的样本

进一步代入公式:

我们将输入X,通过encoder得到z的概率分布,然后进行采样,即:

考虑到“采样”这个操作是随机的,我们无法直接对它求导,于是进行参数重整化

最终,我们最小化:

其中:

下面展示一下怎么从代码层面进行实现(主要讲一下核心的函数):

KL散度函数

# 定义函数,接收两个分布的参数
def KL_between_normals(q_distr, p_distr):
    # 1. 解包输入参数
    # mu_q, sigma_q 的形状通常是 [batch_size, k],k是潜变量维度
    # sigma_q 是标准差 (standard deviation),而不是方差 (variance)
    mu_q, sigma_q = q_distr
    mu_p, sigma_p = p_distr
    
    # 2. 获取潜变量的维度 k
    k = mu_q.size(1)
    
    # 3. 计算均值差的平方项:(μ_p - μ_q)²
    # 这一步是逐元素计算的
    mu_diff = mu_p - mu_q
    mu_diff_sq =  torch.mul(mu_diff, mu_diff) # 等价于 mu_diff ** 2
    
    # 4. 计算两个分布的对数行列式 (log-determinant)
    # 对于对角协方差矩阵 Σ = diag(σ₁², σ₂², ...),其行列式 det(Σ) = Π σᵢ²
    # 对数行列式 log(det(Σ)) = log(Π σᵢ²) = Σ log(σᵢ²) = Σ 2*log(σᵢ)
    # torch.clamp 是为了数值稳定性,防止 sigma 变为 0 导致 log(0) = -inf
    logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1)
    logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1)
    
    # 5. 计算公式中的两个主要分数项
    # fs 对应于公式中的 Σ [ (σ_q²/σ_p²) + ((μ_p - μ_q)²/σ_p²) ]
    # torch.div 是逐元素除法
    # torch.sum(..., dim=1) 是沿着维度k进行求和,得到每个样本的总和
    # (sigma_q ** 2) 是方差 σ²
    term1 = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1)  # 对应 tr(Σ_p⁻¹ * Σ_q)
    term2 = torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1) # 对应 (μ_p - μ_q)ᵀ Σ_p⁻¹ (μ_p - μ_q)
    fs = term1 + term2
    
    # 6. 组合所有部分
    # 这里计算的是 2 * KL(q || p)
    # 2*KL = Σ [ σ_q²/σ_p² + (μ_p-μ_q)²/σ_p² - 1 - log(σ_q²) + log(σ_p²) ]
    #       = (fs) - Σ(1) + Σ(log(σ_p²)) - Σ(log(σ_q²))
    #       = fs - k + logdet_sigma_p - logdet_sigma_q
    two_kl =  fs - k + logdet_sigma_p - logdet_sigma_q
    
    # 7. 返回最终的 KL 散度值
    # 将结果除以2,得到最终的KL散度。
    # 输出的张量形状为 [batch_size],每个元素是对应样本的KL散度值
    return two_kl * 0.5

解释:

KL数学解释:

VIB模型:

class VIB(nn.Module):
    def __init__(self, X_dim, y_dim, dimZ=256, beta=1e-3, num_samples=10):
        super().__init__()
        
        # 1. 保存超参数
        self.beta = beta               # VIB损失函数中的权重 β
        self.dimZ = dimZ             # 潜变量 Z 的维度
        self.num_samples = num_samples # 蒙特卡洛采样次数,用于估算期望
        
        # 2. 定义编码器
        # 这是一个简单的多层感知机 (MLP)
        self.encoder = nn.Sequential(nn.Linear(in_features=X_dim, out_features=1024),
                                     nn.ReLU(),
                                     nn.Linear(in_features=1024, out_features=1024),
                                     nn.ReLU(),
                                     # 关键: 输出维度是 2 * dimZ
                                     # 一半用于均值(mu),一半用于标准差(sigma)
                                     nn.Linear(in_features=1024, out_features=2 * self.dimZ))
        
        # 3. 定义解码器
        # 论文中使用了简单的逻辑回归,这里就是一个线性层
        # 它接收潜变量 Z (dimZ),输出预测 Y 的 logits (y_dim)
        self.decoder_logits = nn.Linear(in_features=self.dimZ, out_features=y_dim)


    def gaussian_noise(self, num_samples, K):
        # works with integers as well as tuples   
        return torch.normal(torch.zeros(*num_samples, K), torch.ones(*num_samples, K)).cuda()
           
    def sample_prior_Z(self, num_samples):
        return self.gaussian_noise(num_samples=num_samples, K=self.dimZ)

    def encoder_result(self, batch):
        encoder_output = self.encoder(batch)
        
        mu = encoder_output[:, :self.dimZ]
        sigma = torch.nn.functional.softplus(encoder_output[:, self.dimZ:])
        
        return mu, sigma
    
    def sample_encoder_Z(self, num_samples, batch): 
        batch_size = batch.size()[0]
        mu, sigma = self.encoder_result(batch)
        
        return mu + sigma * self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
    
    
    def forward(self, batch_x):
        
        batch_size = batch_x.size()[0]
        
        # sample from encoder
        encoder_Z_distr = self.encoder_result(batch_x)  
        to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)

        decoder_logits_mean = torch.mean(self.decoder_logits(to_decoder), dim=0)
                
        return decoder_logits_mean
        
    def batch_loss(self, num_samples, batch_x, batch_y):
        batch_size = batch_x.size()[0]
        
        # --- 第一部分:计算 KL 散度项 I(Z;X) ---
        # 1. 定义先验分布 N(0, I)
        prior_Z_distr = torch.zeros(batch_size, self.dimZ).cuda(), torch.ones(batch_size, self.dimZ).cuda()
        # 2. 获取编码器产生的后验分布 p(z|x)
        encoder_Z_distr = self.encoder_result(batch_x)
        
        # 3. 计算 KL(p(z|x) || p(z)) 并对批次取平均
        I_ZX_bound = torch.mean(KL_between_normals(encoder_Z_distr, prior_Z_distr))
        
        # --- 第二部分:计算交叉熵项 -I(Z;Y) ---
        # 1. 使用重参数化技巧从 p(z|x) 中采样Z
        # to_decoder 形状: (num_samples, batch_size, dimZ)
        to_decoder = self.sample_encoder_Z(num_samples=self.num_samples, batch=batch_x)
        
        # 2. 将Z送入解码器得到预测
        # decoder_logits 形状: (num_samples, batch_size, y_dim)
        decoder_logits = self.decoder_logits(to_decoder)
        
        # 3. 准备计算交叉熵损失
        # nn.CrossEntropyLoss期望的输入形状是 (N, C, ...),其中N是批大小,C是类别数
        # 当前形状是 (num_samples, batch_size, y_dim)
        # 我们把它变成 (batch_size, y_dim, num_samples) 来匹配PyTorch的要求
        decoder_logits = decoder_logits.permute(1, 2, 0)
    
        # 4. 计算每个样本、每个采样的交叉熵损失
        # reduce=False: 不对损失求和或求平均,返回每个元素的损失
        loss = nn.CrossEntropyLoss(reduce=False)
        # batch_y 形状是 (batch_size),需要扩展成 (batch_size, num_samples)
        # 以便和 (batch_size, y_dim, num_samples) 的预测logits对应
        cross_entropy_loss = loss(decoder_logits, batch_y[:, None].expand(-1, num_samples))
        
        # 5. 通过对采样维度求平均来估算期望 E[log q(y|z)]
        cross_entropy_loss_montecarlo = torch.mean(cross_entropy_loss, dim=-1)
        
        # 6. 对批次求平均,得到最终的交叉熵损失项
        minusI_ZY_bound = torch.mean(cross_entropy_loss_montecarlo, dim=0)
                
        # --- 合并两部分,得到最终的总损失 ---
        # J_IB = -I(Z;Y) + β * I(Z;X)
        total_loss = minusI_ZY_bound + self.beta * I_ZX_bound
        
        # 返回总损失和两个分量,方便监控训练过程
        return total_loss, -minusI_ZY_bound, I_ZX_bound

参考代码链接:https://github.com/makezur/VIB_pytorch/blob/master/VIB.ipynb

`.vib` 文件是一种 **振动数据文件格式**,通常用于存储设备、机械或结构的振动信号数据。这种文件格式常见于 **工业监测系统、旋转机械析、预测性维护(PdM)、状态监测(CBM)** 等领域。 --- ## ✅ 一、什么是 `.vib` 文件? `.vib` 文件通常包含: - 振动波形数据(时域信号) - 频率谱数据(频域信号) - 时间戳 - 采样率 - 传感器信息(如加速度计、速度计) - 设备信息(如轴承、齿轮、电机) - 单位(如 mm/s、g、m/s²) 这些数据通常由 **振动传感器**(如加速度计)采集,并通过析工具进行频谱析、包络析、FFT、峭度析等,用于检测设备故障(如轴承损坏、不对中、不平衡、齿轮断裂等)。 --- ## ✅ 二、常见的 `.vib` 文件来源 | 来源 | 说明 | |------|------| | **SKF @ptitude** | SKF 提供的状态监测软件,支持 `.vib` 文件格式 | | **Pronax CBM** | 用于旋转机械状态监测,生成 `.vib` 文件 | | **Brüel & Kjær** | 振动析仪生成 `.vib` 文件 | | **通用工业监测系统** | 一些自定义的振动数据采集系统也使用 `.vib` 格式 | --- ## ✅ 三、如何打开 `.vib` 文件? ### 1. 使用 SKF @ptitude - 官网: [https://www.skf.com](https://www.skf.com) - 说明:SKF 提供的专用振动析软件,支持 `.vib` 文件导入、析、报告生成。 - 支持功能: - 时域/频域析 - 包络析 - 轴心轨迹图 - 趋势析 ### 2. 使用 MATLAB 你可以使用 MATLAB 的 `fread` 或 `textscan` 函数读取 `.vib` 文件(前提是它是 ASCII 格式)。 ```matlab filename = 'example.vib'; fid = fopen(filename, 'r'); data = fread(fid, Inf, 'float32'); % 假设是 float32 格式 fclose(fid); plot(data); title('Vibration Data'); ``` > 注意:`.vib` 文件可能是二进制格式,需根据文件结构解析。 ### 3. 使用 Python 读取 `.vib` 文件(假设是二进制) ```python import numpy as np with open('example.vib', 'rb') as f: data = np.fromfile(f, dtype=np.float32) # 假设是 float32 类型 print(data.shape) print(data[:10]) # 打印前10个数据点 ``` 你也可以使用 `struct` 模块按字节解析: ```python import struct with open('example.vib', 'rb') as f: while chunk := f.read(4): # 每次读取4字节(float32) value = struct.unpack('f', chunk) print(value[0]) ``` ### 4. 使用开源工具(如 Octave、SciLab) - **GNU Octave**:MATLAB 的开源替代品,支持类似语法。 - **SciLab**:另一个科学计算工具,支持读取二进制文件。 --- ## ✅ 四、如何解析 `.vib` 文件结构? 由于 `.vib` 文件没有统一的公开格式标准,你需要知道其结构才能正确解析: | 信息 | 说明 | |------|------| | 文件头(Header) | 可能包含采样率、通道数、时间戳等元数据 | | 数据块(Data Block) | 振动数据,可能是 float32 或 int16 类型 | | 数据长度 | 通常由采样率 × 持续时间决定 | | 多通道 | 每个通道可能按顺序存储,或交错存储 | --- ## ✅ 五、示例:`.vib` 文件内容结构(假设为 ASCII) ``` # Vibration Data File SampleRate: 10000 Channel: 1 Units: mm/s StartTime: 2024-01-01T12:00:00Z Data: 0.001 0.002 0.003 ... ``` 你可以使用 Python 读取并绘制: ```python import numpy as np import matplotlib.pyplot as plt with open('example.vib', 'r') as f: lines = f.readlines() # 跳过头部信息 data_lines = [line.strip() for line in lines if line[0].isdigit()] data = np.array(data_lines, dtype=np.float32) plt.plot(data) plt.title('Vibration Signal') plt.xlabel('Sample') plt.ylabel('Amplitude (mm/s)') plt.show() ``` --- ## ✅ 六、如何将 `.vib` 文件转换为其他格式? ### 1. 转换为 `.csv` ```python np.savetxt('output.csv', data, delimiter=',') ``` ### 2. 转换为 `.wav`(音频格式) ```python from scipy.io.wavfile import write # 假设采样率为 10000 Hz write('output.wav', 10000, data.astype(np.float32)) ``` ### 3. 转换为 `.mat`(MATLAB 文件) ```python from scipy.io import savemat savemat('output.mat', {'vibration_data': data}) ``` --- ## ✅ 七、总结 | 功能 | 工具/方法 | |------|-----------| | 打开 `.vib` 文件 | SKF @ptitude、MATLAB、Python | | 解析 `.vib` 文件 | Python(`numpy`, `struct`) | | 绘制 `.vib` 数据 | Matplotlib | | 转换 `.vib` 文件 | CSV、WAV、MAT | | 析 `.vib` 数据 | FFT、包络谱、时频析、峭度析 | --- ##
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值