制作医疗健康检测 AI 模型

🚑 制作一个医疗健康检测 AI 模型

医疗健康检测模型可以用于 疾病预测、医学影像分析、生理信号监测 等任务。

本文介绍如何使用 深度学习(CNN/LSTM)+ 医疗数据 构建高效的医疗 AI 模型。


📌 1. 选择 AI 模型类型

医疗健康检测涉及多个方向:

  1. 医学影像分析(X-ray、MRI、CT 等)
    • 适用模型:CNN(ResNet、EfficientNet)、ViT(Vision Transformer)
    • 应用:肺炎检测、肿瘤识别、骨折检测
  2. 生理信号监测(心电 ECG、脑电 EEG)
    • 适用模型:LSTM、Transformer、CNN
    • 应用:心律异常、癫痫预测
  3. 健康指标预测(糖尿病、心脏病风险)
    • 适用模型:XGBoost、随机森林、DNN
    • 应用:糖尿病预测、心血管风险评估

📌 2. 以「肺炎 X-ray 影像检测」为例

任务:使用 CNN 训练 AI 模型,识别 X-ray 影像中的 肺炎 是否存在。
数据集:使用 Kaggle「Chest X-ray」数据集(含「正常」和「肺炎」X-ray 图片)。


📌 3. 安装依赖

pip install tensorflow keras numpy matplotlib opencv-python pandas sklearn

📌 4. 数据预处理

加载 X-ray 影像数据,并进行数据增强

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 设置数据路径
train_dir = "chest_xray/train"
val_dir = "chest_xray/val"
test_dir = "chest_xray/test"

# 数据增强 & 归一化
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)

test_datagen = ImageDataGenerator(rescale=1./255)

# 加载数据集
train_data = train_datagen.flow_from_directory(train_dir, target_size=(224, 224), batch_size=32, class_mode='binary')
val_data = test_datagen.flow_from_directory(val_dir, target_size=(224, 224), batch_size=32, class_mode='binary')
test_data = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=32, class_mode='binary')

📌 5. 构建 CNN 诊断模型

使用 EfficientNetB0(比 ResNet 轻量,准确率高)。

from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model

# 加载 EfficientNet 预训练模型
base_model = EfficientNetB0(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # 冻结预训练参数

# 构建自定义分类头
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation="relu")(x)
x = Dropout(0.3)(x)
output = Dense(1, activation="sigmoid")(x)

model = Model(inputs=base_model.input, outputs=output)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# 训练模型
history = model.fit(train_data, validation_data=val_data, epochs=10)

# 保存模型
model.save("pneumonia_detector.h5")

📌 6. 评估模型

# 计算测试集准确率
test_loss, test_acc = model.evaluate(test_data)
print(f"Test Accuracy: {test_acc * 100:.2f}%")

# 绘制训练过程
plt.plot(history.history["accuracy"], label="Train Accuracy")
plt.plot(history.history["val_accuracy"], label="Val Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

📌 7. 进行肺炎检测

import cv2

# 预测单张 X-ray 影像
def predict_image(img_path):
    img = cv2.imread(img_path)
    img = cv2.resize(img, (224, 224))
    img = np.expand_dims(img / 255.0, axis=0)  # 归一化
    prediction = model.predict(img)[0][0]
    return "Pneumonia" if prediction > 0.5 else "Normal"

# 测试新图像
image_path = "chest_xray/test/NORMAL/image1.jpg"
print(f"Prediction: {predict_image(image_path)}")

📌 8. 部署 AI 诊断 API

使用 FastAPI 部署 AI 诊断服务:

from fastapi import FastAPI, UploadFile, File
import uvicorn
import numpy as np
from PIL import Image
import tensorflow as tf

app = FastAPI()
model = tf.keras.models.load_model("pneumonia_detector.h5")

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    img = Image.open(file.file).resize((224, 224))
    img = np.array(img) / 255.0
    img = np.expand_dims(img, axis=0)
    
    prediction = model.predict(img)[0][0]
    result = "Pneumonia" if prediction > 0.5 else "Normal"
    return {"prediction": result}

# 运行服务器
# uvicorn main:app --reload

测试 API

curl -X 'POST' 'http://127.0.0.1:8000/predict/' -F 'file=@image.jpg'

📌 9. 扩展思路

其他疾病预测(糖尿病、心脏病风险):

  • 数据集:UCI 糖尿病数据集、Framingham 心血管疾病数据
  • 模型:XGBoost / CNN-LSTM
    多模态医疗 AI(结合影像 + 生理信号)
    强化学习 + AI 诊断推荐

📌 10. 结论

🚀 通过 CNN(EfficientNet)+ 影像数据,我们构建了一个 肺炎检测 AI 模型,并通过 FastAPI 部署 作为在线诊断服务。
 

🚀 使用 Transformer 构建医疗健康检测 AI 模型

Transformer 模型已广泛应用于医学影像、文本分析、生理信号预测等任务。本文介绍如何使用 Vision Transformer(ViT)Transformer Encoder 进行医疗 AI 预测。


📌 1. Transformer 在医疗中的应用

Transformer 在医疗领域主要应用如下:
医学影像分析(X-ray、MRI、CT):Vision Transformer (ViT)
医学文本处理(电子病历、医学摘要):BERT / GPT
生理信号预测(心电 ECG、脑电 EEG):Transformer Encoder

我们以 肺炎 X-ray 影像分类(ViT) 为例,演示 Transformer 在医学影像分类中的应用


📌 2. 安装依赖

pip install torch torchvision timm numpy matplotlib

📌 3. 使用 Vision Transformer (ViT) 进行肺炎检测

ViT(Vision Transformer) 通过将图像划分为 Patch,并使用 Transformer 进行特征提取,适用于医学影像分析。

🔥 3.1 加载 ViT 预训练模型

import torch
import timm
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# 选择预训练的 ViT 模型
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=2)
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# 载入 X-ray 影像
image_path = "chest_xray/test/PNEUMONIA/image1.jpg"
image = Image.open(image_path).convert("RGB")
image_tensor = transform(image).unsqueeze(0)

# 预测肺炎
with torch.no_grad():
    output = model(image_tensor)
    pred_class = torch.argmax(output, dim=1).item()

print("预测结果:", "肺炎" if pred_class == 1 else "正常")

🔥 3.2 训练 ViT 进行肺炎检测

使用 ViT 进行迁移学习,在胸部 X-ray 数据集上训练模型:

import torch.nn as nn
from torch.optim import Adam
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 加载数据集
train_data = ImageFolder("chest_xray/train", transform=transform)
val_data = ImageFolder("chest_xray/val", transform=transform)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16, shuffle=False)

# 修改 ViT 分类头
model.head = nn.Linear(model.head.in_features, 2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4)

# 训练 ViT
for epoch in range(10):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# 保存模型
torch.save(model.state_dict(), "vit_pneumonia.pth")

📌 4. Transformer 处理医学文本

BERT / GPT-4 可以用于医学文本分析,例如

  • 电子病历摘要(Medical Report Summarization)
  • 医疗问答(Medical Q&A)
  • 疾病预测(Disease Prediction)

🔥 4.1 电子病历摘要(BERT)

from transformers import BertTokenizer, BertForSequenceClassification
import torch

# 载入 BERT 预训练模型
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# 处理病历文本
text = "Patient is experiencing severe cough and high fever for 5 days. Possible pneumonia."
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

# 预测病情
with torch.no_grad():
    output = model(**inputs)
    prediction = torch.argmax(output.logits).item()

print("预测结果:", "肺炎" if prediction == 1 else "非肺炎")

📌 5. Transformer 进行生理信号分析(LSTM + Transformer)

任务:分析 ECG(心电图)数据,预测是否存在 心律失常

🔥 5.1 预处理 ECG 数据

import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# 载入 ECG 数据(示例)
data = pd.read_csv("ecg.csv")  # ECG 记录数据
scaler = MinMaxScaler()
ecg_signals = scaler.fit_transform(data.values)

🔥 5.2 训练 Transformer 进行预测

import torch.nn as nn

# 定义 Transformer 模型
class ECGTransformer(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=8), num_layers=4
        )
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        x = self.encoder(x)
        x = self.fc(x[:, -1, :])  # 取最后一个时间步
        return x

# 初始化模型
model = ECGTransformer(input_dim=128, num_classes=2)

📌 6. 结论

🚀 ViT 适用于医学影像,BERT/GPT 适用于医学文本,而 Transformer Encoder 可用于生理信号预测
 

🚀 结合多模态 AI-- 预测心脏病(ECG + 症状文本)

心脏病预测可结合 生理信号(ECG心电图)患者症状文本(电子病历),利用 Transformer 进行多模态分析,提高预测准确率。


📌 1. 方案概述

输入:

  • ECG(心电图数据):使用 LSTM + Transformer Encoder 提取特征
  • 病历文本(患者症状):使用 BERT 处理文本信息

输出

  • 预测是否患有心脏病(0: 正常, 1: 高风险)

📌 2. 安装依赖

pip install torch torchvision transformers numpy pandas scikit-learn

📌 3. ECG 心电图特征提取(LSTM + Transformer)

LSTM 适合处理时间序列数据,结合 Transformer Encoder 提取深度特征

import torch
import torch.nn as nn

class ECGTransformer(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, num_layers=2, num_classes=128):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        transformer_out = self.transformer_encoder(lstm_out)
        feature = self.fc(transformer_out[:, -1, :])  # 取最后一个时间步
        return feature

# 初始化模型
ecg_model = ECGTransformer()

🔥 载入 ECG 数据

import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# 载入 ECG 数据
ecg_data = pd.read_csv("ecg.csv")  # ECG 心电数据
scaler = MinMaxScaler()
ecg_signals = scaler.fit_transform(ecg_data.values)  # 归一化

# 转换为 Tensor
ecg_tensor = torch.tensor(ecg_signals, dtype=torch.float32).unsqueeze(0)  # 添加 batch 维度

📌 4. 电子病历文本特征提取(BERT)

BERT 提取患者症状的文本特征

from transformers import BertTokenizer, BertModel

# 加载 BERT 预训练模型
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")

def extract_text_features(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        features = bert_model(**inputs).pooler_output  # 取池化输出
    return features

🔥 载入病人症状

text_data = "Patient reports chest pain, high blood pressure, and irregular heartbeat."
text_features = extract_text_features(text_data)

📌 5. 多模态融合(ECG + 症状文本)

我们使用 全连接网络(FC) 进行多模态融合,并进行心脏病预测

class MultiModalHeartDiseaseModel(nn.Module):
    def __init__(self, ecg_feature_dim=128, text_feature_dim=768, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(ecg_feature_dim + text_feature_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, ecg_features, text_features):
        x = torch.cat((ecg_features, text_features), dim=1)  # 拼接特征
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化多模态 AI
multi_modal_model = MultiModalHeartDiseaseModel()

📌 6. 进行心脏病预测

# 提取 ECG 特征
with torch.no_grad():
    ecg_features = ecg_model(ecg_tensor)

# 预测心脏病
with torch.no_grad():
    output = multi_modal_model(ecg_features, text_features)
    prediction = torch.argmax(output).item()

print("心脏病预测:", "高风险" if prediction == 1 else "正常")

🚀 结合多模态 AI 预测心律失常(ECG + 电子病历)

心律失常(Arrhythmia)是一种常见的心脏疾病,可使用 ECG(心电图)数据 结合 电子病历文本 进行预测,提高诊断准确率。


📌 1. 方案概述

输入

  • ECG(心电图数据) → 采用 LSTM + Transformer 提取时序特征
  • 电子病历(患者症状) → 采用 BERT 处理文本特征

输出

  • 预测心律是否正常(0: 正常, 1: 心律失常)

📌 2. 安装依赖

pip install torch torchvision transformers numpy pandas scikit-learn

📌 3. ECG 信号特征提取(LSTM + Transformer)

ECG 是时间序列数据,使用 LSTM 处理长时依赖信息,并结合 Transformer 提取更深层次的特征

import torch
import torch.nn as nn

class ECGTransformer(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, num_layers=2, num_classes=128):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        transformer_out = self.transformer_encoder(lstm_out)
        feature = self.fc(transformer_out[:, -1, :])  # 取最后一个时间步
        return feature

# 初始化模型
ecg_model = ECGTransformer()

🔥 载入 ECG 数据

import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# 载入 ECG 数据
ecg_data = pd.read_csv("ecg_arrythmia.csv")  # ECG 数据集
scaler = MinMaxScaler()
ecg_signals = scaler.fit_transform(ecg_data.values)  # 归一化

# 转换为 Tensor
ecg_tensor = torch.tensor(ecg_signals, dtype=torch.float32).unsqueeze(0)  # 添加 batch 维度

📌 4. 电子病历文本特征提取(BERT)

患者病历文本提供额外的信息,如 病史、症状、家族病史

from transformers import BertTokenizer, BertModel

# 加载 BERT 预训练模型
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")

def extract_text_features(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        features = bert_model(**inputs).pooler_output  # 取池化输出
    return features

🔥 载入病人症状

text_data = "Patient experiences frequent heart palpitations, irregular heartbeats, and dizziness."
text_features = extract_text_features(text_data)

📌 5. 多模态融合(ECG + 电子病历文本)

使用**全连接网络(FC)**融合 ECG 和文本信息,并进行 心律失常预测

class MultiModalArrhythmiaModel(nn.Module):
    def __init__(self, ecg_feature_dim=128, text_feature_dim=768, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(ecg_feature_dim + text_feature_dim, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_classes)  # 2 类:正常 / 心律失常

    def forward(self, ecg_features, text_features):
        x = torch.cat((ecg_features, text_features), dim=1)  # 拼接特征
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化多模态 AI
multi_modal_model = MultiModalArrhythmiaModel()

📌 6. 进行心律失常预测

# 提取 ECG 特征
with torch.no_grad():
    ecg_features = ecg_model(ecg_tensor)

# 预测心律失常
with torch.no_grad():
    output = multi_modal_model(ecg_features, text_features)
    prediction = torch.argmax(output).item()

print("心律失常预测:", "高风险" if prediction == 1 else "正常")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值