import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import shap
import warnings
# Suppress every warning for the rest of the process.
warnings.filterwarnings('ignore')
# Enable Chinese text in figures: matplotlib tries these fonts in order.
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]
# Use the ASCII hyphen for negative numbers instead of the Unicode minus sign
# (presumably because the CJK fonts above may lack that glyph).
plt.rcParams['axes.unicode_minus'] = False
# 1. 数据加载与预处理
def load_and_preprocess(disease):
    """Load one disease dataset, clean it, one-hot encode it and split it.

    Parameters
    ----------
    disease : str
        ``"stroke"``, ``"heart"`` or ``"cirrhosis"``. Selects the CSV file
        read from the working directory, the feature columns and the target.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, preprocessor, features)``.
        ``X_train``/``X_test`` are dense encoded matrices (one-hot columns
        first, then the passthrough columns), ``preprocessor`` is the fitted
        ``ColumnTransformer`` and ``features`` lists the raw feature columns.

    Raises
    ------
    ValueError
        If ``disease`` is not one of the supported names.
    """
    if disease == "stroke":
        df = pd.read_csv("stroke.csv")
        features = ['age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi', 'smoking_status']
        target = 'stroke'
        # Impute missing BMI with the median. Assign the result back instead
        # of chained ``inplace=True``, which is deprecated and has no effect
        # under pandas copy-on-write.
        df['bmi'] = df['bmi'].fillna(df['bmi'].median())
        cat_features = ['smoking_status']
    elif disease == "heart":
        df = pd.read_csv("heart.csv")
        features = ['Age', 'ChestPainType', 'Cholesterol', 'MaxHR', 'ExerciseAngina', 'ST_Slope']
        target = 'HeartDisease'
        # A cholesterol reading of 0 is treated as missing and replaced by
        # the column median.
        df['Cholesterol'] = df['Cholesterol'].replace(0, df['Cholesterol'].median())
        cat_features = ['ChestPainType', 'ExerciseAngina', 'ST_Slope']
    elif disease == "cirrhosis":
        df = pd.read_csv("cirrhosis.csv")
        features = ['Bilirubin', 'Albumin', 'Prothrombin', 'Edema', 'Stage', 'Platelets']
        # Assumes 'D' in Status marks a positive case; adjust to the real data.
        target = 'Status'
        # float64 columns get the median, everything else the most frequent value.
        for col in features:
            if df[col].dtype == 'float64':
                fill_value = df[col].median()
            else:
                fill_value = df[col].mode()[0]
            df[col] = df[col].fillna(fill_value)
        # Binarise the target: 'D' -> 1, anything else (including NaN) -> 0.
        df[target] = (df[target] == 'D').astype(int)
        cat_features = ['Edema', 'Stage']
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported disease {disease!r}; "
            "expected 'stroke', 'heart' or 'cirrhosis'"
        )

    X = df[features]
    y = df[target]

    # One-hot encode the categorical columns and pass the rest through.
    # sparse_threshold=0 forces a dense result so downstream plotting and
    # SHAP code never receives a scipy sparse matrix.
    preprocessor = ColumnTransformer(
        transformers=[('cat', OneHotEncoder(drop='first'), cat_features)],
        remainder='passthrough',
        sparse_threshold=0,
    )
    X_processed = preprocessor.fit_transform(X)

    X_train, X_test, y_train, y_test = train_test_split(
        X_processed, y, test_size=0.2, random_state=42
    )
    return X_train, X_test, y_train, y_test, preprocessor, features
# 2. 训练随机森林模型
def train_rf(X_train, y_train):
    """Fit and return a random forest classifier on the training split.

    The forest uses 100 trees, a maximum depth of 8 and a fixed seed (42).
    """
    forest_params = {"n_estimators": 100, "max_depth": 8, "random_state": 42}
    # ``fit`` returns the estimator itself, so the fitted model is returned.
    return RandomForestClassifier(**forest_params).fit(X_train, y_train)
# 3. 模型评估
def evaluate_model(model, X_test, y_test, disease):
    """Print test metrics and show the confusion matrix and ROC curve.

    Parameters
    ----------
    model : fitted classifier exposing ``predict`` and ``predict_proba``.
    X_test, y_test : held-out features and binary (0/1) labels.
    disease : str
        Name used in the printed header and the figure titles.

    Returns
    -------
    tuple
        ``(y_pred, y_prob)``: predicted labels and the predicted probability
        of the second class (column 1 of ``predict_proba``).
    """
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    # zero_division=0 makes the "no positive predictions" case explicit
    # rather than relying on a suppressed warning.
    metrics = {
        '准确率': accuracy_score(y_test, y_pred),
        '精确率': precision_score(y_test, y_pred, zero_division=0),
        '召回率': recall_score(y_test, y_pred, zero_division=0),
        'F1分数': f1_score(y_test, y_pred, zero_division=0),
        'AUC': roc_auc_score(y_test, y_prob)
    }
    print(f"\n{disease}模型评估指标:")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    # Confusion matrix heatmap. labels=[0, 1] pins the matrix to 2x2 so it
    # always lines up with the two hard-coded tick labels.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    plt.figure(figsize=(6, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['未患病', '患病'],
                yticklabels=['未患病', '患病'])
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.title(f'{disease}模型混淆矩阵')
    plt.show()

    # ROC curve with the chance diagonal for reference.
    fpr, tpr, _ = roc_curve(y_test, y_prob)
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, label=f'AUC = {metrics["AUC"]:.4f}')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('假阳性率')
    plt.ylabel('真阳性率')
    plt.title(f'{disease}模型ROC曲线')
    plt.legend()
    plt.show()
    return y_pred, y_prob
# 4. SHAP灵敏度分析
import matplotlib.pyplot as plt
import numpy as np
import shap
def _select_class_shap_values(shap_values):
    """Reduce TreeExplainer output to one (n_samples, n_features) matrix.

    Older shap releases return a list with one array per class; newer ones
    return a single 3-D array (n_samples, n_features, n_classes). For two
    classes the positive class (index 1) is used, otherwise class 0.
    A 2-D input is returned unchanged.
    """
    if isinstance(shap_values, list):
        target_class = 1 if len(shap_values) == 2 else 0
        return np.asarray(shap_values[target_class])
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        target_class = 1 if shap_values.shape[2] == 2 else 0
        return shap_values[:, :, target_class]
    return shap_values


def shap_analysis(model, X_train, X_test, features, disease, preprocessor):
    """Plot global SHAP feature importance and return the top-3 features.

    Parameters
    ----------
    model : fitted tree-based classifier.
    X_train : unused; kept so existing callers keep working.
    X_test : encoded test matrix produced by ``preprocessor``.
    features : list of raw feature column names given to ``preprocessor``.
    disease : str
        Name used in the figure title and the printed summary.
    preprocessor : fitted ``ColumnTransformer`` whose first transformer is
        the one-hot encoder registered under the name ``'cat'``.

    Returns
    -------
    list of str
        Up to three encoded feature names, ordered by decreasing mean
        absolute SHAP value.

    Raises
    ------
    ValueError
        If the number of columns in ``X_test`` does not match the number of
        reconstructed feature names.
    """
    # Rebuild encoded column names in ColumnTransformer output order:
    # one-hot columns first, then the passthrough columns. The encoder's own
    # column list is used so the names match what it was fitted on.
    cat_features = list(preprocessor.transformers_[0][2])
    num_features = [f for f in features if f not in cat_features]
    ohe = preprocessor.named_transformers_['cat']
    cat_names = ohe.get_feature_names_out(cat_features)
    feature_names = list(cat_names) + num_features

    # Densify a scipy sparse matrix before handing it to SHAP and the plot.
    if hasattr(X_test, "toarray"):
        X_test = X_test.toarray()

    if X_test.shape[1] != len(feature_names):
        raise ValueError(
            f"特征维度不匹配:X_test有{X_test.shape[1]}列,feature_names有{len(feature_names)}个名称"
        )

    explainer = shap.TreeExplainer(model)
    shap_values = _select_class_shap_values(explainer.shap_values(X_test))

    # show=False is the key fix. With the default show=True, summary_plot
    # displays and closes its own figure, so the title / tight_layout / show
    # calls below would act on a brand-new empty figure (axes, nothing drawn).
    plt.figure(figsize=(10, 6))
    shap.summary_plot(
        shap_values,
        X_test,
        feature_names=feature_names,
        plot_type="bar",
        show=False,
    )
    plt.title(f'{disease}模型SHAP特征重要性')
    plt.tight_layout()  # keep long feature labels from being clipped
    plt.show()

    # Rank features by mean |SHAP| and keep the three largest, descending.
    mean_abs_shap = np.abs(shap_values).mean(axis=0)
    top3_idx = np.argsort(mean_abs_shap)[::-1][:3]
    top3_features = [feature_names[i] for i in top3_idx]
    print(f"{disease}模型最敏感的3个因素(降序):{top3_features}")
    return top3_features
##def shap_analysis(model, X_train, X_test, features, disease, preprocessor):
## # 提取特征名称(处理编码后的特征)
## cat_features = [f for f in features if f in preprocessor.transformers_[0][2]]
## num_features = [f for f in features if f not in cat_features]
## ohe = preprocessor.named_transformers_['cat']
## cat_names = ohe.get_feature_names_out(cat_features)
## feature_names = list(cat_names) + num_features
##
## # SHAP值计算
## explainer = shap.TreeExplainer(model)
## shap_values = explainer.shap_values(X_test)
##
## # 摘要图(全局特征重要性)
## plt.figure(figsize=(10, 6))
## shap.summary_plot(shap_values, X_test, feature_names=feature_names, plot_type="bar")
## plt.title(f'{disease}模型SHAP特征重要性')
## plt.show()
##
## # 前3个最敏感特征
## shap_sum = np.abs(shap_values).mean(0)
## assert np.issubdtype(top3_idx.dtype, np.integer), "Indices must be integers"
## top3_idx = np.array(top3_idx, dtype=int).flatten()
## top3_features = [feature_names[i] for i in top3_idx]
## print(f"{disease}模型最敏感的3个因素:{top3_features}")
##
## return top3_features
# 5. 主函数
def main():
    """Run load -> train -> evaluate -> SHAP for every disease.

    Returns:
        dict: maps each disease name to a dict holding the first ten
        predictions, the first ten predicted probabilities and the list of
        most sensitive features returned by ``shap_analysis``.
    """
    summary = {}
    for disease in ("stroke", "heart", "cirrhosis"):
        print(f"\n===== {disease}预测模型 =====")
        split = load_and_preprocess(disease)
        X_train, X_test, y_train, y_test, preprocessor, features = split
        model = train_rf(X_train, y_train)
        y_pred, y_prob = evaluate_model(model, X_test, y_test, disease)
        top_factors = shap_analysis(
            model, X_train, X_test, features, disease, preprocessor
        )
        # Keep only a sample (first 10 rows) of the predictions.
        summary[disease] = {
            '预测值(前10)': y_pred[:10],
            '预测概率(前10)': y_prob[:10],
            '最敏感因素': top_factors,
        }
    return summary
# Run the pipeline only when executed as a script (not on import); the
# per-disease summary is kept in the module-level name ``results``.
if __name__ == "__main__":
    results = main()
# 问题:以上代码在绘制 SHAP 摘要图时,为什么只显示横纵坐标轴,而图中没有任何条形或折线?请解释原因,给出相应代码的改进措施,并提供修改后的完整代码。
# 最新发布