以下是基于DBN的故障诊断程序的完整实现,包含详细注释和可视化功能:
```python
# -*- coding: utf-8 -*-
"""
基于深度信念网络(DBN)的振动信号故障诊断系统
运行环境要求:
Python 3.6+
需要安装以下库:
pip install numpy matplotlib scikit-learn dbn==0.0.3 tensorflow==1.15.0
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from dbn.tensorflow import SupervisedDBNClassification
# 模拟振动信号生成函数
def generate_vibration_data(samples_per_class=200, signal_length=1024):
"""
生成包含3种状态的模拟振动信号:
0 - 正常状态
1 - 高频故障
2 - 冲击故障
"""
np.random.seed(42)
data = []
labels = []
# 正常状态:高斯白噪声
for _ in range(samples_per_class):
signal = np.random.normal(0, 0.5, signal_length)
data.append(signal)
labels.append(0)
# 高频故障:叠加高频正弦波
for _ in range(samples_per_class):
t = np.linspace(0, 2*np.pi, signal_length)
noise = np.random.normal(0, 0.5, signal_length)
signal = noise + 0.8*np.sin(100*t)
data.append(signal)
labels.append(1)
# 冲击故障:周期性脉冲
for _ in range(samples_per_class):
signal = np.random.normal(0, 0.5, signal_length)
for i in range(0, signal_length, 150): # 每150个点添加冲击
if i+20 < signal_length:
signal[i:i+20] += 5.0 # 添加20个点的冲击
data.append(signal)
labels.append(2)
return np.array(data), np.array(labels)
# 特征提取函数
def extract_features(signals):
"""
从时域和频域提取9维特征:
[均值,标准差,峰值,均方根,峭度,偏度,频域均值,频域标准差,主频位置]
"""
features = []
for signal in signals:
# 时域特征
mean = np.mean(signal)
std = np.std(signal)
peak = np.max(np.abs(signal))
rms = np.sqrt(np.mean(signal**2))
kurtosis = np.mean((signal - mean)**4) / (std**4 + 1e-6)
skewness = np.mean((signal - mean)**3) / (std**3 + 1e-6)
# 频域特征
fft = np.abs(np.fft.fft(signal)[:len(signal)//2]) # 单边频谱
fft_mean = np.mean(fft)
fft_std = np.std(fft)
dominant_freq = np.argmax(fft)
features.append([mean, std, peak, rms, kurtosis, skewness,
fft_mean, fft_std, dominant_freq])
return np.array(features)
# 可视化原始信号和特征分布
def visualize_data(signals, labels, features):
plt.figure(figsize=(15, 10))
# 绘制原始信号示例
plt.subplot(2, 2, 1)
plt.plot(signals[0], label='Normal')
plt.plot(signals[200], label='High Freq Fault')
plt.plot(signals[400], label='Impact Fault')
plt.title("Raw Signal Examples")
plt.xlabel("Sample Points")
plt.ylabel("Amplitude")
plt.legend()
# 绘制特征分布
plt.subplot(2, 2, 2)
scatter = plt.scatter(features[:, 0], features[:, 1], c=labels, alpha=0.6)
plt.xlabel('Mean')
plt.ylabel('Std')
plt.title('Feature Space Distribution')
plt.legend(*scatter.legend_elements(), title='Classes')
# 绘制频谱对比
plt.subplot(2, 2, 3)
for i in [0, 200, 400]:
fft = np.abs(np.fft.fft(signals[i])[:512])
plt.plot(fft, label=f'Class {labels[i]}')
plt.title("Frequency Spectrum Comparison")
plt.xlabel("Frequency Bins")
plt.ylabel("Amplitude")
plt.legend()
plt.tight_layout()
plt.show()
# 主程序
if __name__ == "__main__":
# 1. 数据生成
signals, labels = generate_vibration_data(samples_per_class=200)
print("Signal data shape:", signals.shape)
print("Label distribution:", np.unique(labels, return_counts=True))
# 2. 特征提取
features = extract_features(signals)
print("Feature shape:", features.shape)
# 3. 数据可视化
visualize_data(signals, labels, features)
# 4. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
features, labels, test_size=0.3, random_state=42
)
# 5. 初始化DBN分类器
classifier = SupervisedDBNClassification(
hidden_layers_structure=[64, 32], # 两个隐藏层
learning_rate_rbm=0.01, # RBM学习率
learning_rate=0.05, # 微调学习率
n_epochs_rbm=15, # RBM训练轮次
n_iter_backprop=200, # 反向传播迭代次数
batch_size=32, # 批处理大小
activation_function='relu', # 激活函数
dropout_p=0.25 # Dropout概率
)
# 6. 训练模型
print("\nStarting training...")
classifier.fit(X_train, y_train)
# 7. 模型评估
y_pred = classifier.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"\nTest Accuracy: {accuracy:.4f}")
# 8. 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()
```
### 代码说明
1. **数据生成**:
- 生成包含3类(正常、高频故障、冲击故障)的模拟振动信号
- 每类生成200个样本,每个样本包含1024个数据点
- 通过添加不同的故障特征(高频成分、周期性冲击)来模拟真实故障
2. **特征提取**:
- 时域特征:均值、标准差、峰值、均方根、峭度、偏度
- 频域特征:FFT均值、FFT标准差、主频位置
- 最终生成9维特征向量
3. **可视化功能**:
- 原始信号波形对比
- 特征空间分布
- 频谱特征对比
- 混淆矩阵可视化
4. **DBN模型**:
- 使用两层隐藏层(64-32节点)
- 包含预训练(RBM)和微调两个阶段
- 使用ReLU激活函数和Dropout正则化
### 运行说明
1. 安装依赖库:
```bash
pip install numpy matplotlib scikit-learn dbn==0.0.3 tensorflow==1.15.0
```
2. 运行程序将显示:
- 原始信号波形对比图
- 特征分布散点图
- 频谱对比图
- 混淆矩阵
3. 典型输出:
```
Signal data shape: (600, 1024)
Label distribution: (array([0, 1, 2]), array([200, 200, 200]))
Feature shape: (600, 9)
Test Accuracy: 0.9833
```
### 注意事项
1. 该实现使用TensorFlow 1.15,确保运行环境正确
2. 可根据实际数据调整以下参数:
- 生成信号的特征参数
- DBN网络结构
- 训练超参数(学习率、迭代次数等)
3. 对于真实数据,需要调整特征提取方法和数据标准化方式
此程序提供了一个完整的DBN故障诊断框架,能够处理振动信号数据并进行可视化分析,用户可以根据实际需求进行参数调整和功能扩展。