-
数据加载与预处理:
- 使用
wfdb
库加载 MIT-BIH 心律失常数据库中的心电信号数据。 - 应用低通滤波器去除高频噪声,确保信号的平滑性。
- 通过峰值检测提取心跳信号,并对心跳进行截取和标记。
- 使用
-
标签生成:
- 读取心电图的注释信息,根据心跳的类型(正常或异常)为每个心跳生成相应的标签。
-
数据填充与格式转换:
- 将提取的心跳信号填充到相同的长度,以便于输入到模型中。
- 将数据划分为训练集和测试集,并使用随机过采样(RandomOverSampler)处理类别不平衡问题。
-
模型构建与训练:
- 定义一个包含两个卷积层和池化层的 CNN 模型,用于提取时序特征。
- 使用二元交叉熵损失函数和 Adam 优化器编译模型,并训练模型。
-
模型评估与预测:
- 在测试集上评估模型的性能,计算损失和准确率。
- 进行预测并将结果与真实标签进行对比,输出每个片段的实际和预测标签。
-
可视化:
- 绘制实际标签和预测标签的对比图,展示模型的分类效果。
- 可视化训练过程中的准确率和损失变化,帮助分析模型的学习情况
import wfdb
from scipy.signal import butter, filtfilt, find_peaks
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from imblearn.over_sampling import RandomOverSampler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense
# 低通滤波器,去除高频噪声
def butter_lowpass_filter(data, cutoff, fs, order=5):
nyquist = 0.5 * fs
normal_cutoff = cutoff / nyquist
b, a = butter(order, normal_cutoff, btype='low', analog=False)
y = filtfilt(b, a, data)
return y
# 加载数据
def load_data(record_path):
record = wfdb.rdrecord(record_path)
return record
# 处理心拍数据
def preprocess_data(record):
filtered_signal = butter_lowpass_filter(record.p_signal[:, 0], cutoff=50, fs=record.fs)
peaks, _ = find_peaks(filtered_signal, height=0.5, distance=record.fs * 0.6)
heartbeat_length = int(record.fs) # 1 秒的长度
heartbeats = []
valid_peaks = []
for peak in peaks:
start = peak - heartbeat_length // 2
end = peak + heartbeat_length // 2