- 数据加载与预处理:
  - 使用 `wfdb` 库加载 MIT-BIH 心律失常数据库中的心电信号数据。
  - 应用低通滤波器(截止频率 50 Hz)去除高频噪声,使信号更平滑。
  - 通过峰值检测定位心跳的 R 峰,并以每个峰值为中心截取 1 秒长的心跳片段(标记方法见下文"标签生成")。
- 标签生成:
  - 读取心电图的注释信息(`.atr` 文件),为每个心跳片段生成标签:若片段窗口内存在异常心跳注释(A、S、V、F、J),则标记为 1(心律失常),否则标记为 0(正常)。
- 数据填充与格式转换:
  - 将提取的心跳片段用零填充到相同的长度,以便输入到模型中。
  - 将数据按 7:3 划分为训练集和测试集,并仅对训练集使用随机过采样(RandomOverSampler)处理类别不平衡问题。
- 模型构建与训练:
  - 定义一个一维 CNN 模型:两组"卷积层 + 最大池化层"用于提取时序特征,之后接全连接层和 sigmoid 输出层。
  - 使用二元交叉熵损失函数和 Adam 优化器编译模型,并训练 10 个 epoch。
- 模型评估与预测:
  - 在测试集上评估模型的性能,计算损失和准确率,并将模型保存为 `ecg_model.h5`。
  - 进行预测,并将结果与真实标签进行对比,输出每个片段的实际标签和预测标签。
- 可视化:
  - 绘制实际标签和预测标签的对比图,展示模型的分类效果。
  - 可视化训练过程中的准确率和损失变化,帮助分析模型的学习情况。
import wfdb
from scipy.signal import butter, filtfilt, find_peaks
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from imblearn.over_sampling import RandomOverSampler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense
def butter_lowpass_filter(data, cutoff, fs, order=5):
    """Apply a zero-phase Butterworth low-pass filter to *data*.

    Used to remove high-frequency noise from the ECG before peak detection.

    Parameters:
        data: signal to filter (the caller passes a single ECG channel).
        cutoff: cutoff frequency, in the same units as ``fs`` (Hz here).
        fs: sampling frequency of ``data``.
        order: Butterworth filter order (default 5).

    Returns:
        The filtered signal, with the same shape as ``data``.
    """
    # butter() expects the cutoff normalised to the Nyquist frequency (fs / 2).
    wn = cutoff / (fs / 2.0)
    numer, denom = butter(order, wn, btype='low', analog=False)
    # filtfilt runs the filter forward and then backward, cancelling phase shift.
    return filtfilt(numer, denom, data)
def load_data(record_path):
    """Load a WFDB record from disk.

    Parameters:
        record_path: record path without file extension, e.g.
            ``'mit-bih-arrhythmia-database-1.0.0/201'`` as used by ``main``.

    Returns:
        The object returned by ``wfdb.rdrecord`` for that path.
    """
    return wfdb.rdrecord(record_path)
def preprocess_data(record):
    """Filter channel 0 of *record* and cut a fixed-width window around each peak.

    The first signal channel is low-pass filtered at 50 Hz. Peaks higher than
    0.5 and at least 0.6 s apart are then detected. For every peak, the samples
    ``[peak - int(fs) // 2, peak + int(fs) // 2)`` are cut out (about 1 second).
    Peaks whose window would run past either end of the signal are reported
    and skipped.

    Returns:
        (heartbeats, valid_peaks): a NumPy array with one window per row, and
        the list of peak indices that produced those rows, in the same order.
    """
    signal = butter_lowpass_filter(record.p_signal[:, 0], cutoff=50, fs=record.fs)
    # Minimum peak spacing of 0.6 s, expressed in samples.
    peak_indices, _ = find_peaks(signal, height=0.5, distance=record.fs * 0.6)

    half_window = int(record.fs) // 2
    n_samples = len(signal)
    beats, kept_peaks = [], []
    for peak in peak_indices:
        lo, hi = peak - half_window, peak + half_window
        # Guard clause: the window must lie entirely inside the signal.
        if lo < 0 or hi >= n_samples:
            print(f"峰值索引 {peak} 超出范围,跳过此心拍。")
            continue
        beats.append(signal[lo:hi])
        kept_peaks.append(peak)
    return np.array(beats), kept_peaks
def create_labels(record, valid_peaks, heartbeat_length,
                  annotation_dir='mit-bih-arrhythmia-database-1.0.0',
                  normal_symbols=('N', 'L', 'R'),
                  abnormal_symbols=('A', 'S', 'V', 'F', 'J')):
    """Assign a binary label (0 = normal, 1 = arrhythmia) to each heartbeat window.

    The reference ``atr`` annotations for ``record.record_name`` are read from
    ``annotation_dir``. The window ``[peak - heartbeat_length // 2,
    peak + heartbeat_length // 2)`` around each peak is labelled 1 if any
    annotation inside it has a symbol in ``abnormal_symbols``, otherwise 0.

    Parameters:
        record: WFDB record; only ``record_name`` and ``p_signal`` are used.
        valid_peaks: sample indices of the detected peaks to label.
        heartbeat_length: window width in samples (callers pass ``int(fs)``,
            i.e. one second).
        annotation_dir: directory holding the annotation files. The default is
            the original hard-coded location; pass the record's own directory
            when the database lives elsewhere.
        normal_symbols: annotation symbols treated as normal beats (label 0).
        abnormal_symbols: annotation symbols treated as arrhythmic beats
            (label 1). NOTE(review): the AAMI EC57 grouping also counts 'a'
            (aberrated APC) and 'E' (ventricular escape) as abnormal. They are
            excluded here to keep the original labelling; confirm the intended
            label set.

    Returns:
        Integer array of 0/1 labels, one per peak whose window lies fully
        inside the signal. Out-of-range peaks are reported and skipped, using
        the same bounds rule as the segmentation step, so the result stays
        aligned with the extracted segments.
    """
    annotation = wfdb.rdann(f'{annotation_dir}/{record.record_name}', 'atr')
    signal_length = len(record.p_signal)

    # Per-sample flag: True where an abnormal beat annotation sits. Samples
    # without a matching annotation stay False (normal). A later annotation
    # at the same sample overrides an earlier one.
    abnormal_at = np.zeros(signal_length, dtype=bool)
    for sample_index, symbol in zip(annotation.sample, annotation.symbol):
        if symbol in normal_symbols:
            abnormal_at[sample_index] = False
        elif symbol in abnormal_symbols:
            abnormal_at[sample_index] = True

    half = heartbeat_length // 2
    segment_labels = []
    for peak in valid_peaks:
        start_index = peak - half
        end_index = peak + half
        if start_index < 0 or end_index >= signal_length:
            print(f"峰值索引 {peak} 超出范围,跳过此标注。")
            continue
        # NOTE(review): a one-second window can also contain the annotation of
        # a neighbouring beat, so a normal beat next to an arrhythmic one may
        # be labelled 1. Confirm this is intended.
        segment_labels.append(int(abnormal_at[start_index:end_index].any()))
    # dtype=int keeps an empty result integer instead of float.
    return np.array(segment_labels, dtype=int)
def _build_model(input_length):
    """Build and compile the 1-D CNN binary classifier used by ``main``."""
    model = Sequential([
        Conv1D(filters=32, kernel_size=3, activation='relu', input_shape=(input_length, 1)),
        MaxPooling1D(pool_size=2),
        Conv1D(filters=64, kernel_size=3, activation='relu'),
        MaxPooling1D(pool_size=2),
        Flatten(),
        Dense(64, activation='relu'),
        Dense(1, activation='sigmoid'),  # single sigmoid unit: binary output
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


def _plot_label_comparison(y_true, y_pred):
    """Show actual and predicted labels side by side, one point per segment."""
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(y_true, marker='o', label='Actual', color='b')
    plt.title('Actual Labels')
    plt.xlabel('Segment Index')
    plt.ylabel('Label')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(y_pred, marker='o', label='Predicted', color='r')
    plt.title('Predicted Labels')
    plt.xlabel('Segment Index')
    plt.ylabel('Label')
    plt.legend()
    plt.tight_layout()
    plt.show()


def _plot_training_curve(history, metric, name):
    """Plot the training and validation values of *metric* per epoch."""
    plt.plot(history.history[metric], label=f'Train {name}')
    plt.plot(history.history[f'val_{metric}'], label=f'Test {name}')
    plt.title(f'Model {name}')
    plt.ylabel(name)
    plt.xlabel('Epoch')
    plt.legend()
    plt.show()


def main(record_path='mit-bih-arrhythmia-database-1.0.0/201'):
    """Train, evaluate, save and visualise the ECG arrhythmia CNN.

    Pipeline:
        1. Load the record and cut it into one-second heartbeat windows.
        2. Label each window.
        3. Split the data 70/30 into train and test sets.
        4. Random-oversample the training split.
        5. Train a 1-D CNN for 10 epochs and evaluate it on the test split.
        6. Save the model to ``ecg_model.h5``.
        7. Plot the labels and the training curves.

    Parameters:
        record_path: WFDB record path without extension; adjust it to the
            local database location. NOTE: ``create_labels`` reads the
            annotation file from its own directory setting, which must point
            at the same database.

    Raises:
        ValueError: if no complete heartbeat window is found, or if the number
            of segments and labels differ.
    """
    record = load_data(record_path)

    # Filter and segment. valid_peaks lists the peaks whose window was kept.
    heartbeats, valid_peaks = preprocess_data(record)
    if len(heartbeats) == 0:
        raise ValueError('未检测到完整的心拍片段,请检查记录路径和峰值检测参数。')

    # Window width shared by segmentation and labelling: one second of samples.
    heartbeat_length = int(record.fs)
    segment_labels = create_labels(record, valid_peaks, heartbeat_length)
    if len(segment_labels) != len(heartbeats):
        raise ValueError(
            f'心拍数量 ({len(heartbeats)}) 与标签数量 ({len(segment_labels)}) 不一致。')

    # Zero-pad every segment to exactly heartbeat_length samples. A window of
    # peak ± heartbeat_length // 2 is one sample short when the length is odd.
    X = np.array([
        np.pad(heartbeat, (0, max(0, heartbeat_length - len(heartbeat))), 'constant')
        for heartbeat in heartbeats
    ])
    X = np.expand_dims(X, axis=-1)  # add the trailing channel axis for Conv1D

    # NOTE(review): consider stratify=segment_labels so both splits keep the
    # class ratio. Left unstratified to reproduce the original split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, segment_labels, test_size=0.3, random_state=43)

    # Oversample the training split only, so the test split keeps its real
    # class balance. The sampler gets 2-D (n_samples, n_features) input,
    # hence the flatten/reshape round trip.
    ros = RandomOverSampler(random_state=42)
    X_resampled, y_resampled = ros.fit_resample(X_train.reshape(X_train.shape[0], -1), y_train)
    X_resampled = X_resampled.reshape(-1, heartbeat_length, 1)

    model = _build_model(heartbeat_length)
    # The test split is also passed as validation data, so the "Test" curves
    # plotted below come from the same segments as the final evaluation.
    history = model.fit(X_resampled, y_resampled, epochs=10, batch_size=32,
                        validation_data=(X_test, y_test))
    loss, accuracy = model.evaluate(X_test, y_test)
    model.save('ecg_model.h5')

    # Threshold the sigmoid output at 0.5 to obtain 0/1 class labels.
    predictions = model.predict(X_test)
    predicted_labels = (predictions > 0.5).astype(int)

    _plot_label_comparison(y_test, predicted_labels)
    for i, (actual, predicted) in enumerate(zip(y_test, predicted_labels[:, 0])):
        print(f'Segment {i}: 实际结果 = {actual}, 预测结果 = {predicted}')
    print(f'Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}')

    _plot_training_curve(history, 'accuracy', 'Accuracy')
    _plot_training_curve(history, 'loss', 'Loss')


# Run the training pipeline only when executed as a script.
if __name__ == '__main__':
    main()
训练代码如上。
预测代码如下。
import wfdb
import numpy as np
from scipy.signal import butter, filtfilt, find_peaks
from tensorflow.keras.models import load_model
def butter_lowpass_filter(data, cutoff, fs, order=5):
    """Zero-phase Butterworth low-pass filter, identical to the training script's.

    Parameters:
        data: signal to filter (here a single ECG channel).
        cutoff: cutoff frequency in the units of ``fs`` (Hz).
        fs: sampling frequency.
        order: filter order (default 5).

    Returns:
        The filtered signal, same shape as ``data``.
    """
    # Normalise the cutoff by the Nyquist frequency, then filter forward and
    # backward so the output has no phase delay.
    b, a = butter(order, cutoff / (0.5 * fs), btype='low', analog=False)
    return filtfilt(b, a, data)
def create_labels(record, valid_peaks, heartbeat_length,
                  annotation_dir='mit-bih-arrhythmia-database-1.0.0',
                  normal_symbols=('N', 'L', 'R'),
                  abnormal_symbols=('A', 'S', 'V', 'F', 'J')):
    """Assign a binary label (0 = normal, 1 = arrhythmia) to each heartbeat window.

    The reference ``atr`` annotations for ``record.record_name`` are read from
    ``annotation_dir``. The window ``[peak - heartbeat_length // 2,
    peak + heartbeat_length // 2)`` around each peak is labelled 1 if any
    annotation inside it has a symbol in ``abnormal_symbols``, otherwise 0.

    Parameters:
        record: WFDB record; only ``record_name`` and ``p_signal`` are used.
        valid_peaks: sample indices of the detected peaks to label.
        heartbeat_length: window width in samples (the script passes
            ``int(fs)``, i.e. one second).
        annotation_dir: directory holding the annotation files. The default is
            the original hard-coded location; pass the record's own directory
            when the database lives elsewhere.
        normal_symbols: annotation symbols treated as normal beats (label 0).
        abnormal_symbols: annotation symbols treated as arrhythmic beats
            (label 1). NOTE(review): the AAMI EC57 grouping also counts 'a'
            (aberrated APC) and 'E' (ventricular escape) as abnormal. They are
            excluded here to keep the original labelling; confirm the intended
            label set.

    Returns:
        Integer array of 0/1 labels, one per peak whose window lies fully
        inside the signal. Out-of-range peaks are reported and skipped, using
        the same bounds rule as the segmentation loop, so the result stays
        aligned with the extracted segments.
    """
    annotation = wfdb.rdann(f'{annotation_dir}/{record.record_name}', 'atr')
    signal_length = len(record.p_signal)

    # Per-sample flag: True where an abnormal beat annotation sits. Samples
    # without a matching annotation stay False (normal). A later annotation
    # at the same sample overrides an earlier one.
    abnormal_at = np.zeros(signal_length, dtype=bool)
    for sample_index, symbol in zip(annotation.sample, annotation.symbol):
        if symbol in normal_symbols:
            abnormal_at[sample_index] = False
        elif symbol in abnormal_symbols:
            abnormal_at[sample_index] = True

    half = heartbeat_length // 2
    segment_labels = []
    for peak in valid_peaks:
        start_index = peak - half
        end_index = peak + half
        if start_index < 0 or end_index >= signal_length:
            print(f"峰值索引 {peak} 超出范围,跳过此标注。")
            continue
        # NOTE(review): a one-second window can also contain the annotation of
        # a neighbouring beat, so a normal beat next to an arrhythmic one may
        # be labelled 1. Confirm this is intended.
        segment_labels.append(int(abnormal_at[start_index:end_index].any()))
    # dtype=int keeps an empty result integer instead of float.
    return np.array(segment_labels, dtype=int)
# Prediction script: run the saved CNN over every heartbeat of one record and
# compare the predictions with labels derived from the reference annotations.
# NOTE(review): as written, the training script also uses record 201 (70% of
# its segments for training), so the accuracy below partly measures training
# data and overstates generalisation. Evaluate on a held-out record for a fair
# estimate.

# Load the model saved by the training script.
model = load_model('ecg_model.h5')

# Load the ECG record (adjust the path to the local database location).
record_path = 'mit-bih-arrhythmia-database-1.0.0/201'
record = wfdb.rdrecord(record_path)

# Same preprocessing as training: 50 Hz low-pass filter on channel 0.
filtered_signal = butter_lowpass_filter(record.p_signal[:, 0], cutoff=50, fs=record.fs)

# Peak candidates: height above 0.5, at least 0.6 s apart (same as training).
peaks, _ = find_peaks(filtered_signal, height=0.5, distance=record.fs * 0.6)

# Cut a one-second window around each peak, keeping only windows that lie
# fully inside the signal. Record the kept peaks so that the labels created
# below line up with the segments by construction.
heartbeat_length = int(record.fs)
heartbeats = []
valid_peaks = []
for peak in peaks:
    start = peak - heartbeat_length // 2
    end = peak + heartbeat_length // 2
    if start >= 0 and end < len(filtered_signal):
        heartbeats.append(filtered_signal[start:end])
        valid_peaks.append(peak)

if not heartbeats:
    raise ValueError('未检测到完整的心拍片段,请检查记录路径和峰值检测参数。')

# Zero-pad to exactly heartbeat_length samples and add the channel axis.
heartbeats_padded = [np.pad(heartbeat, (0, max(0, heartbeat_length - len(heartbeat))), 'constant') for heartbeat in heartbeats]
X = np.array(heartbeats_padded)
X = np.expand_dims(X, axis=-1)

# Predict and threshold the sigmoid output at 0.5.
predictions = model.predict(X)
predicted_labels = (predictions > 0.5).astype(int)

# Reference labels for exactly the segmented peaks.
segment_labels = create_labels(record, valid_peaks, heartbeat_length)

# Defensive: both sequences come from the same in-bounds peaks, so this
# truncation should be a no-op.
min_length = min(len(segment_labels), len(predicted_labels))
segment_labels = segment_labels[:min_length]
predicted_labels = predicted_labels[:min_length]
predicted_flat = predicted_labels.ravel()

# Print the prediction and the reference label for every heartbeat.
for i, (predicted, actual) in enumerate(zip(predicted_flat, segment_labels)):
    print(f'心拍 {i}: 预测结果 = {predicted}, 实际标签 = {actual}')

accuracy = np.mean(segment_labels == predicted_flat)
print(f'模型预测准确率: {accuracy:.4f}')

# Report only the interesting cases: true positives, misses, false alarms.
for i, (actual, predicted) in enumerate(zip(segment_labels, predicted_flat)):
    if actual == 1 and predicted == 1:
        print(f'心拍 {i}: 实际为 1,预测为 1')
    elif actual == 1 and predicted == 0:
        print(f'心拍 {i}: 实际为 1,预测为 0')
    elif actual == 0 and predicted == 1:
        print(f'心拍 {i}: 实际为 0,预测为 1')
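运行时,需要根据本机 MIT-BIH 数据库的实际位置修改以下路径(训练脚本和预测脚本中均需修改;`create_labels` 读取注释文件时使用的目录 `mit-bih-arrhythmia-database-1.0.0/` 也需与之保持一致):
`record_path = 'mit-bih-arrhythmia-database-1.0.0/201'`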
测试集结果: