一个基于卷积神经网络(CNN)的心电图(ECG)心律失常分类模型

  • 数据加载与预处理

    • 使用 wfdb 库加载 MIT-BIH 心律失常数据库中的心电信号数据。
    • 应用低通滤波器去除高频噪声,确保信号的平滑性。
    • 通过峰值检测提取心跳信号,并对心跳进行截取和标记。
  • 标签生成

    • 读取心电图的注释信息,根据心跳的类型(正常或异常)为每个心跳生成相应的标签。
  • 数据填充与格式转换

    • 将提取的心跳信号填充到相同的长度,以便于输入到模型中。
    • 将数据划分为训练集和测试集,并使用随机过采样(RandomOverSampler)处理类别不平衡问题。
  • 模型构建与训练

    • 定义一个包含两个卷积层和池化层的 CNN 模型,用于提取时序特征。
    • 使用二元交叉熵损失函数和 Adam 优化器编译模型,并训练模型。
  • 模型评估与预测

    • 在测试集上评估模型的性能,计算损失和准确率。
    • 进行预测并将结果与真实标签进行对比,输出每个片段的实际和预测标签。
  • 可视化

    • 绘制实际标签和预测标签的对比图,展示模型的分类效果。
    • 可视化训练过程中的准确率和损失变化,帮助分析模型的学习情况
import wfdb
from scipy.signal import butter, filtfilt, find_peaks
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from imblearn.over_sampling import RandomOverSampler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense

# 低通滤波器,去除高频噪声
def butter_lowpass_filter(data, cutoff, fs, order=5):
    nyquist = 0.5 * fs
    normal_cutoff = cutoff / nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    y = filtfilt(b, a, data)
    return y

# 加载数据
def load_data(record_path):
    record = wfdb.rdrecord(record_path)
    return record

# 处理心拍数据
def preprocess_data(record):
    filtered_signal = butter_lowpass_filter(record.p_signal[:, 0], cutoff=50, fs=record.fs)
    peaks, _ = find_peaks(filtered_signal, height=0.5, distance=record.fs * 0.6)

    heartbeat_length = int(record.fs)  # 1 秒的长度
    heartbeats = []
    valid_peaks = []

    for peak in peaks:
        start = peak - heartbeat_length // 2
        end = peak + heartbeat_length // 2
        if start >= 0 and end < len(filtered_signal):  # 确保不越界
            heartbeats.append(filtered_signal[start:end])
            valid_peaks.append(peak)  # 仅添加有效峰值
        else:
            print(f"峰值索引 {peak} 超出范围,跳过此心拍。")

    return np.array(heartbeats), valid_peaks

def create_labels(record, valid_peaks, heartbeat_length):
    # 使用 record.record_name 获取完整的文件名
    record_name = record.record_name  # 例如 '100'
    annotation = wfdb.rdann(f'mit-bih-arrhythmia-database-1.0.0/{record_name}', 'atr')
    labels = annotation.symbol
    sample_indices = annotation.sample

    automatic_labels = np.zeros(len(record.p_signal))  # 默认标记为正常(0)
    for i in range(len(sample_indices)):
        sample_index = sample_indices[i]
        if labels[i] in ['N', 'L', 'R']:
            automatic_labels[sample_index] = 0  # 正常心跳
        elif labels[i] in ['A', 'S', 'V', 'F', 'J']:
            automatic_labels[sample_index] = 1  # 心律失常

    segment_labels = []
    for peak in valid_peaks:
        start_index = peak - heartbeat_length // 2
        end_index = peak + heartbeat_length // 2
        if start_index < 0 or end_index >= len(automatic_labels):
            print(f"峰值索引 {peak} 超出范围,跳过此标注。")
            continue

        # 检查当前心拍的标签
        if np.any(automatic_labels[start_index:end_index] == 1):
            segment_labels.append(1)
        else:
            segment_labels.append(0)

    return np.array(segment_labels)


# 主函数
def main():
    record_path = 'mit-bih-arrhythmia-database-1.0.0/201'
    record = load_data(record_path)

    # 数据预处理
    heartbeats, valid_peaks = preprocess_data(record)

    # 假设你在主函数中定义了 heartbeat_length
    heartbeat_length = int(record.fs)  # 1 秒的长度
    segment_labels = create_labels(record, valid_peaks, heartbeat_length)

    # 填充到相同长度
    heartbeat_length = int(record.fs)
    heartbeats_padded = [np.pad(heartbeat, (0, max(0, heartbeat_length - len(heartbeat))), 'constant') for heartbeat in heartbeats]

    # 转换为 numpy 数组
    X = np.array(heartbeats_padded)
    X = np.expand_dims(X, axis=-1)  # 添加维度

    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, segment_labels, test_size=0.3, random_state=43)

    # 上采样异常标签
    ros = RandomOverSampler(random_state=42)
    X_resampled, y_resampled = ros.fit_resample(X_train.reshape(X_train.shape[0], -1), y_train)

    # 转换为原始形状
    X_resampled = X_resampled.reshape(-1, heartbeat_length, 1)

    # 定义模型
    model = Sequential([
        Conv1D(filters=32, kernel_size=3, activation='relu', input_shape=(heartbeat_length, 1)),
        MaxPooling1D(pool_size=2),
        Conv1D(filters=64, kernel_size=3, activation='relu'),
        MaxPooling1D(pool_size=2),
        Flatten(),
        Dense(64, activation='relu'),
        Dense(1, activation='sigmoid')  # 输出层,假设是二分类问题
    ])

    # 编译模型
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # 训练模型
    history = model.fit(X_resampled, y_resampled, epochs=10, batch_size=32, validation_data=(X_test, y_test))

    # 评估模型
    loss, accuracy = model.evaluate(X_test, y_test)

    # 保存模型
    model.save('ecg_model.h5')

    # 进行预测
    predictions = model.predict(X_test)

    # 将预测结果转换为类别标签
    predicted_labels = (predictions > 0.5).astype(int)

    # 可视化实际结果与预测结果
    plt.figure(figsize=(12, 6))

    # 真实标签
    plt.subplot(1, 2, 1)
    plt.plot(y_test, marker='o', label='Actual', color='b')
    plt.title('Actual Labels')
    plt.xlabel('Segment Index')
    plt.ylabel('Label')
    plt.legend()

    # 预测标签
    plt.subplot(1, 2, 2)
    plt.plot(predicted_labels, marker='o', label='Predicted', color='r')
    plt.title('Predicted Labels')
    plt.xlabel('Segment Index')
    plt.ylabel('Label')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # 打印对比结果
    for i in range(len(y_test)):
        print(f'Segment {i}: 实际结果 = {y_test[i]}, 预测结果 = {predicted_labels[i][0]}')

    print(f'Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}')

    # 可视化训练过程
    plt.plot(history.history['accuracy'], label='Train Accuracy')
    plt.plot(history.history['val_accuracy'], label='Test Accuracy')
    plt.title('Model Accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend()
    plt.show()

    plt.plot(history.history['loss'], label='Train Loss')
    plt.plot(history.history['val_loss'], label='Test Loss')
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend()
    plt.show()

# 运行主函数
if __name__ == '__main__':
    main()

 训练代码如上。

预测代码如下。

 

import wfdb
import numpy as np
from scipy.signal import butter, filtfilt, find_peaks
from tensorflow.keras.models import load_model

# 定义低通滤波器函数
def butter_lowpass_filter(data, cutoff, fs, order=5):
    nyquist = 0.5 * fs
    normal_cutoff = cutoff / nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    y = filtfilt(b, a, data)
    return y

# 创建标签函数
def create_labels(record, valid_peaks, heartbeat_length):
    record_name = record.record_name
    annotation = wfdb.rdann(f'mit-bih-arrhythmia-database-1.0.0/{record_name}', 'atr')
    labels = annotation.symbol
    sample_indices = annotation.sample

    automatic_labels = np.zeros(len(record.p_signal))  # 默认标记为正常(0)
    for i in range(len(sample_indices)):
        sample_index = sample_indices[i]
        if labels[i] in ['N', 'L', 'R']:
            automatic_labels[sample_index] = 0  # 正常心跳
        elif labels[i] in ['A', 'S', 'V', 'F', 'J']:
            automatic_labels[sample_index] = 1  # 心律失常

    segment_labels = []
    for peak in valid_peaks:
        start_index = peak - heartbeat_length // 2
        end_index = peak + heartbeat_length // 2
        if start_index < 0 or end_index >= len(automatic_labels):
            print(f"峰值索引 {peak} 超出范围,跳过此标注。")
            continue

        if np.any(automatic_labels[start_index:end_index] == 1):
            segment_labels.append(1)
        else:
            segment_labels.append(0)

    return np.array(segment_labels)

# 加载模型
model = load_model('ecg_model.h5')

# 加载 ECG 数据
record_path = 'mit-bih-arrhythmia-database-1.0.0/201'
record = wfdb.rdrecord(record_path)

# 预处理数据
filtered_signal = butter_lowpass_filter(record.p_signal[:, 0], cutoff=50, fs=record.fs)

# 检测峰值
peaks, _ = find_peaks(filtered_signal, height=0.5, distance=record.fs * 0.6)

# 划分心拍
heartbeat_length = int(record.fs)  # 1 秒的长度
heartbeats = []

for peak in peaks:
    start = peak - heartbeat_length // 2
    end = peak + heartbeat_length // 2
    if start >= 0 and end < len(filtered_signal):
        heartbeats.append(filtered_signal[start:end])

# 转换为 numpy 数组,并进行填充
heartbeats_padded = [np.pad(heartbeat, (0, max(0, heartbeat_length - len(heartbeat))), 'constant') for heartbeat in heartbeats]
X = np.array(heartbeats_padded)
X = np.expand_dims(X, axis=-1)

# 进行预测
predictions = model.predict(X)

# 将预测结果转换为类别标签
predicted_labels = (predictions > 0.5).astype(int)

# 创建标签
valid_peaks = peaks  # 使用前面检测的峰值
segment_labels = create_labels(record, valid_peaks, heartbeat_length)

# 确保预测和标签的长度一致
min_length = min(len(segment_labels), len(predicted_labels))
segment_labels = segment_labels[:min_length]
predicted_labels = predicted_labels[:min_length]

# 逐个输出预测和实际标签
for i in range(min_length):
    print(f'心拍 {i}: 预测结果 = {predicted_labels[i][0]}, 实际标签 = {segment_labels[i]}')

# 计算准确率
accuracy = np.sum(segment_labels == predicted_labels.flatten()) / min_length
print(f'模型预测准确率: {accuracy:.4f}')

# 逐个输出预测和实际标签,并打印特定情况的心拍位置
for i in range(min_length):
    actual = segment_labels[i]
    predicted = predicted_labels[i][0]

    if actual == 1 and predicted == 1:
        print(f'心拍 {i}: 实际为 1,预测为 1')
    elif actual == 1 and predicted == 0:
        print(f'心拍 {i}: 实际为 1,预测为 0')
    elif actual == 0 and predicted == 1:
        print(f'心拍 {i}: 实际为 0,预测为 1')

 运行时候,需要修改对应电脑的数据库:

record_path = 'mit-bih-arrhythmia-database-1.0.0/201'

测试集结果:

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值