基于Tensorflow1.15 实现的DBN故障诊断的完整程序

以下是基于DBN的故障诊断程序的完整实现,包含详细注释和可视化功能:

```python
# -*- coding: utf-8 -*-
"""
基于深度信念网络(DBN)的振动信号故障诊断系统
运行环境要求:
Python 3.6+
需要安装以下库:
pip install numpy matplotlib scikit-learn dbn==0.0.3 tensorflow==1.15.0
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from dbn.tensorflow import SupervisedDBNClassification

# 模拟振动信号生成函数
def generate_vibration_data(samples_per_class=200, signal_length=1024):
    """
    生成包含3种状态的模拟振动信号:
    0 - 正常状态
    1 - 高频故障
    2 - 冲击故障
    """
    np.random.seed(42)
    data = []
    labels = []
    
    # 正常状态:高斯白噪声
    for _ in range(samples_per_class):
        signal = np.random.normal(0, 0.5, signal_length)
        data.append(signal)
        labels.append(0)
    
    # 高频故障:叠加高频正弦波
    for _ in range(samples_per_class):
        t = np.linspace(0, 2*np.pi, signal_length)
        noise = np.random.normal(0, 0.5, signal_length)
        signal = noise + 0.8*np.sin(100*t)
        data.append(signal)
        labels.append(1)
    
    # 冲击故障:周期性脉冲
    for _ in range(samples_per_class):
        signal = np.random.normal(0, 0.5, signal_length)
        for i in range(0, signal_length, 150):  # 每150个点添加冲击
            if i+20 < signal_length:
                signal[i:i+20] += 5.0  # 添加20个点的冲击
        data.append(signal)
        labels.append(2)
    
    return np.array(data), np.array(labels)

# 特征提取函数
def extract_features(signals):
    """
    从时域和频域提取9维特征:
    [均值,标准差,峰值,均方根,峭度,偏度,频域均值,频域标准差,主频位置]
    """
    features = []
    for signal in signals:
        # 时域特征
        mean = np.mean(signal)
        std = np.std(signal)
        peak = np.max(np.abs(signal))
        rms = np.sqrt(np.mean(signal**2))
        kurtosis = np.mean((signal - mean)**4) / (std**4 + 1e-6)
        skewness = np.mean((signal - mean)**3) / (std**3 + 1e-6)
        
        # 频域特征
        fft = np.abs(np.fft.fft(signal)[:len(signal)//2])  # 单边频谱
        fft_mean = np.mean(fft)
        fft_std = np.std(fft)
        dominant_freq = np.argmax(fft)
        
        features.append([mean, std, peak, rms, kurtosis, skewness, 
                        fft_mean, fft_std, dominant_freq])
    return np.array(features)

# 可视化原始信号和特征分布
def visualize_data(signals, labels, features):
    plt.figure(figsize=(15, 10))
    
    # 绘制原始信号示例
    plt.subplot(2, 2, 1)
    plt.plot(signals[0], label='Normal')
    plt.plot(signals[200], label='High Freq Fault')
    plt.plot(signals[400], label='Impact Fault')
    plt.title("Raw Signal Examples")
    plt.xlabel("Sample Points")
    plt.ylabel("Amplitude")
    plt.legend()
    
    # 绘制特征分布
    plt.subplot(2, 2, 2)
    scatter = plt.scatter(features[:, 0], features[:, 1], c=labels, alpha=0.6)
    plt.xlabel('Mean')
    plt.ylabel('Std')
    plt.title('Feature Space Distribution')
    plt.legend(*scatter.legend_elements(), title='Classes')
    
    # 绘制频谱对比
    plt.subplot(2, 2, 3)
    for i in [0, 200, 400]:
        fft = np.abs(np.fft.fft(signals[i])[:512])
        plt.plot(fft, label=f'Class {labels[i]}')
    plt.title("Frequency Spectrum Comparison")
    plt.xlabel("Frequency Bins")
    plt.ylabel("Amplitude")
    plt.legend()
    
    plt.tight_layout()
    plt.show()

# 主程序
if __name__ == "__main__":
    # 1. 数据生成
    signals, labels = generate_vibration_data(samples_per_class=200)
    print("Signal data shape:", signals.shape)
    print("Label distribution:", np.unique(labels, return_counts=True))
    
    # 2. 特征提取
    features = extract_features(signals)
    print("Feature shape:", features.shape)
    
    # 3. 数据可视化
    visualize_data(signals, labels, features)
    
    # 4. 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.3, random_state=42
    )
    
    # 5. 初始化DBN分类器
    classifier = SupervisedDBNClassification(
        hidden_layers_structure=[64, 32],  # 两个隐藏层
        learning_rate_rbm=0.01,           # RBM学习率
        learning_rate=0.05,               # 微调学习率
        n_epochs_rbm=15,                  # RBM训练轮次
        n_iter_backprop=200,              # 反向传播迭代次数
        batch_size=32,                    # 批处理大小
        activation_function='relu',       # 激活函数
        dropout_p=0.25                    # Dropout概率
    )
    
    # 6. 训练模型
    print("\nStarting training...")
    classifier.fit(X_train, y_train)
    
    # 7. 模型评估
    y_pred = classifier.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"\nTest Accuracy: {accuracy:.4f}")
    
    # 8. 绘制混淆矩阵
    cm = confusion_matrix(y_test, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()
```

### 代码说明

1. **数据生成**:
   - 生成包含3类(正常、高频故障、冲击故障)的模拟振动信号
   - 每类生成200个样本,每个样本包含1024个数据点
   - 通过添加不同的故障特征(高频成分、周期性冲击)来模拟真实故障

2. **特征提取**:
   - 时域特征:均值、标准差、峰值、均方根、峭度、偏度
   - 频域特征:FFT均值、FFT标准差、主频位置
   - 最终生成9维特征向量

3. **可视化功能**:
   - 原始信号波形对比
   - 特征空间分布
   - 频谱特征对比
   - 混淆矩阵可视化

4. **DBN模型**:
   - 使用两层隐藏层(64-32节点)
   - 包含预训练(RBM)和微调两个阶段
   - 使用ReLU激活函数和Dropout正则化

### 运行说明

1. 安装依赖库:
   ```bash
   pip install numpy matplotlib scikit-learn dbn==0.0.3 tensorflow==1.15.0
   ```

2. 运行程序将显示:
   - 原始信号波形对比图
   - 特征分布散点图
   - 频谱对比图
   - 混淆矩阵

3. 典型输出:
   ```
   Signal data shape: (600, 1024)
   Label distribution: (array([0, 1, 2]), array([200, 200, 200]))
   Feature shape: (600, 9)
   Test Accuracy: 0.9833
   ```

### 注意事项

1. 该实现使用TensorFlow 1.15,确保运行环境正确
2. 可根据实际数据调整以下参数:
   - 生成信号的特征参数
   - DBN网络结构
   - 训练超参数(学习率、迭代次数等)
3. 对于真实数据,需要调整特征提取方法和数据标准化方式

此程序提供了一个完整的DBN故障诊断框架,能够处理振动信号数据并进行可视化分析,用户可以根据实际需求进行参数调整和功能扩展。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ydlhnust

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值