一个基于卷积神经网络(CNN)的心电图(ECG)心律失常分类模型

  • 数据加载与预处理

    • 使用 wfdb 库加载 MIT-BIH 心律失常数据库中的心电信号数据。
    • 应用低通滤波器去除高频噪声,确保信号的平滑性。
    • 通过峰值检测提取心跳信号,并对心跳进行截取和标记。
  • 标签生成

    • 读取心电图的注释信息,根据心跳的类型(正常或异常)为每个心跳生成相应的标签。
  • 数据填充与格式转换

    • 将提取的心跳信号填充到相同的长度,以便于输入到模型中。
    • 将数据划分为训练集和测试集,并使用随机过采样(RandomOverSampler)处理类别不平衡问题。
  • 模型构建与训练

    • 定义一个包含两个卷积层和池化层的 CNN 模型,用于提取时序特征。
    • 使用二元交叉熵损失函数和 Adam 优化器编译模型,并训练模型。
  • 模型评估与预测

    • 在测试集上评估模型的性能,计算损失和准确率。
    • 进行预测并将结果与真实标签进行对比,输出每个片段的实际和预测标签。
  • 可视化

    • 绘制实际标签和预测标签的对比图,展示模型的分类效果。
    • 可视化训练过程中的准确率和损失变化,帮助分析模型的学习情况
import wfdb
from scipy.signal import butter, filtfilt, find_peaks
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from imblearn.over_sampling import RandomOverSampler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense

# 低通滤波器,去除高频噪声
def butter_lowpass_filter(data, cutoff, fs, order=5):
    nyquist = 0.5 * fs
    normal_cutoff = cutoff / nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    y = filtfilt(b, a, data)
    return y

# 加载数据
def load_data(record_path):
    record = wfdb.rdrecord(record_path)
    return record

# 处理心拍数据
def preprocess_data(record):
    filtered_signal = butter_lowpass_filter(record.p_signal[:, 0], cutoff=50, fs=record.fs)
    peaks, _ = find_peaks(filtered_signal, height=0.5, distance=record.fs * 0.6)

    heartbeat_length = int(record.fs)  # 1 秒的长度
    heartbeats = []
    valid_peaks = []

    for peak in peaks:
        start = peak - heartbeat_length // 2
        end = peak + heartbeat_length // 2
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值