C:\ProgramData\anaconda3\python.exe "C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py"
918条数据的年龄分箱分布:
Age_bin
0 0
1 80
2 585
3 253
4 0
Name: count, dtype: int64
5110条数据的年龄分箱分布:
Age_bin
0 966
1 1204
2 1564
3 1190
4 186
Name: count, dtype: int64
412条数据的年龄分箱分布:
Age_bin
0 0
1 69
2 257
3 86
4 0
Name: count, dtype: int64
胆固醇分箱分布(标签0-1-2):
Cholesterol_bin
0 318
1 237
2 363
Name: count, dtype: int64
胆固醇分箱分布(标签0-1-2):
Cholesterol_bin
0 23
1 39
2 252
Name: count, dtype: int64
Sex ChestPainType FastingBS RestingECG ExerciseAngina ST_Slope HeartDisease RestingBP_bin MaxHR_bin Oldpeak_bin Age_bin Cholesterol_bin
0 M ATA 0 Normal N Up 0 1级高血压 高(85-100%) 无压低 2 2
1 F NAP 0 Normal N Flat 1 2级高血压 高(85-100%) 轻度压低 2 0
2 M ATA 0 ST N Up 0 升高前期 低(<60%) 无压低 1 2
3 F ASY 0 Normal Y Flat 1 1级高血压 中(60-85%) 中度压低 2 1
4 M NAP 0 Normal N Up 0 2级高血压 中(60-85%) 无压低 2 0
.. .. ... ... ... ... ... ... ... ... ... ... ...
913 M TA 0 Normal N Flat 1 正常 中(60-85%) 中度压低 2 2
914 M ASY 1 Normal N Flat 1 2级高血压 高(85-100%) 严重压低 3 0
915 M ASY 0 Normal Y Flat 1 升高前期 中(60-85%) 中度压低 2 0
916 F ATA 0 LVH N Flat 1 升高前期 极高(>100%) 无压低 2 1
917 M NAP 0 Normal N Up 0 1级高血压 高(85-100%) 无压低 1 0
[918 rows x 12 columns]
ID N_Days Status Drug Sex Ascites Hepatomegaly Spiders Edema Bilirubin Albumin Stage Bilirubin_bin Albumin_bin Copper_bin Alk_Phos_bin SGOT_bin Tryglicerides_bin Platelets_bin Prothrombin_bin Age_bin Cholesterol_bin
0 1 400 D D-penicillamine F Y Y Y Y 14.5 2.60 4 显著升高 降低 升高 升高 升高 升高 正常 正常 2 2
1 2 4500 C D-penicillamine F N Y Y N 1.1 4.14 3 正常 正常 升高 升高 升高 正常 正常 正常 2 2
2 3 1012 D D-penicillamine M N N N S 1.4 3.48 4 轻度升高 降低 升高 升高 升高 正常 正常 正常 3 0
3 4 1925 D D-penicillamine F N Y Y S 1.8 2.54 4 轻度升高 降低 升高 升高 升高 正常 正常 正常 2 2
4 5 1504 CL Placebo F N Y Y N 3.4 3.53 3 显著升高 正常 升高 升高 升高 正常 减少 正常 1 2
.. ... ... ... ... .. ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
407 414 681 D D-penicillamine F N Y N N 1.2 2.96 3 正常 降低 升高 升高 正常 升高 正常 正常 3 NaN
408 415 1103 C D-penicillamine F N Y N N 0.9 3.83 4 正常 正常 升高 升高 正常 升高 正常 正常 1 NaN
409 416 1055 C D-penicillamine F N Y N N 1.6 3.42 3 轻度升高 降低 升高 升高 正常 升高 减少 正常 2 NaN
410 417 691 C D-penicillamine F N Y N N 0.8 3.75 3 正常 正常 升高 升高 正常 升高 正常 正常 2 NaN
411 418 976 C D-penicillamine F N Y N N 0.7 3.29 4 正常 降低 升高 升高 正常 升高 正常 正常 2 NaN
[412 rows x 22 columns]
id Sex hypertension heart_disease ever_married work_type Residence_type smoking_status stroke glucose_bin Age_bin bmi_bin
0 9046 M 0 1 Yes Private Urban formerly smoked 1 糖尿病 3 肥胖
1 51676 F 0 0 Yes Self-employed Rural never smoked 1 糖尿病 3 偏瘦
2 31112 M 0 1 Yes Private Rural never smoked 1 偏高 4 肥胖
3 60182 F 0 0 Yes Private Urban smokes 1 糖尿病前期 2 肥胖
4 1665 F 1 0 Yes Self-employed Rural never smoked 1 糖尿病前期 3 正常
... ... .. ... ... ... ... ... ... ... ... ... ...
5105 18234 F 1 0 Yes Private Urban never smoked 0 正常 4 NaN
5106 44873 F 0 0 Yes Self-employed Urban never smoked 0 偏高 4 肥胖
5107 19723 F 0 0 Yes Self-employed Rural never smoked 0 正常 1 肥胖
5108 37544 M 0 0 Yes Private Rural formerly smoked 0 糖尿病前期 2 超重
5109 44679 F 0 0 Yes Govt_job Urban Unknown 0 正常 2 超重
[5110 rows x 12 columns]
处理后的数据形状:(918, 12)(无NaN)
处理后的数据形状:(5110, 12)(无NaN)
处理后的数据形状:(412, 22)(无NaN)
Main already
处理 heart 数据集中的稀有类别(目标变量: HeartDisease)
处理 stroke 数据集中的稀有类别(目标变量: stroke)
处理 cirrhosis 数据集中的稀有类别(目标变量: Stage)
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:271: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['Oldpeak_bin'] = df['Oldpeak_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:271: FutureWarning: The behavior of Series.replace (and DataFrame.replace) with CategoricalDtype is deprecated. In a future version, replace will only be used for cases that preserve the categories. To change the categories, use ser.cat.rename_categories instead.
df['Oldpeak_bin'] = df['Oldpeak_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:275: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['RestingBP_bin'] = df['RestingBP_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:275: FutureWarning: The behavior of Series.replace (and DataFrame.replace) with CategoricalDtype is deprecated. In a future version, replace will only be used for cases that preserve the categories. To change the categories, use ser.cat.rename_categories instead.
df['RestingBP_bin'] = df['RestingBP_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:279: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['ST_Slope'] = df['ST_Slope'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:283: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['ChestPainType'] = df['ChestPainType'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:290: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['glucose_bin'] = df['glucose_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:290: FutureWarning: The behavior of Series.replace (and DataFrame.replace) with CategoricalDtype is deprecated. In a future version, replace will only be used for cases that preserve the categories. To change the categories, use ser.cat.rename_categories instead.
df['glucose_bin'] = df['glucose_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:293: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['bmi_bin'] = df['bmi_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:293: FutureWarning: The behavior of Series.replace (and DataFrame.replace) with CategoricalDtype is deprecated. In a future version, replace will only be used for cases that preserve the categories. To change the categories, use ser.cat.rename_categories instead.
df['bmi_bin'] = df['bmi_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:296: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['smoking_status'] = df['smoking_status'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:303: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['Bilirubin_bin'] = df['Bilirubin_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:303: FutureWarning: The behavior of Series.replace (and DataFrame.replace) with CategoricalDtype is deprecated. In a future version, replace will only be used for cases that preserve the categories. To change the categories, use ser.cat.rename_categories instead.
df['Bilirubin_bin'] = df['Bilirubin_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:306: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`
df['Albumin_bin'] = df['Albumin_bin'].replace(
C:\Users\Diwith\Daily_Project\2023ES2\YataiBei_Bayesian Network.py:306: FutureWarning: The behavior of Series.replace (and DataFrame.replace) with CategoricalDtype is deprecated. In a future version, replace will only be used for cases that preserve the categories. To change the categories, use ser.cat.rename_categories instead.
df['Albumin_bin'] = df['Albumin_bin'].replace(
Config already
已为 heart 构建贝叶斯网络
[heart] 模型CPD验证:
节点 ChestPainType 的CPD:
+------------------+----------+
| ChestPainType(0) | 0.542339 |
+------------------+----------+
| ChestPainType(1) | 0.190188 |
+------------------+----------+
| ChestPainType(2) | 0.221102 |
+------------------+----------+
| ChestPainType(3) | 0.046371 |
+------------------+----------+
节点 HeartDisease 的CPD:
+-----------------+-----+---------------------+
| Age_bin | ... | Age_bin(3) |
+-----------------+-----+---------------------+
| ChestPainType | ... | ChestPainType(3) |
+-----------------+-----+---------------------+
| RestingBP_bin | ... | RestingBP_bin(3) |
+-----------------+-----+---------------------+
| ST_Slope | ... | ST_Slope(2) |
+-----------------+-----+---------------------+
| HeartDisease(0) | ... | 0.6628959276018099 |
+-----------------+-----+---------------------+
| HeartDisease(1) | ... | 0.33710407239819007 |
+-----------------+-----+---------------------+
节点 ST_Slope 的CPD:
+-------------+-----------+
| ST_Slope(0) | 0.0689964 |
+-------------+-----------+
| ST_Slope(1) | 0.513889 |
+-------------+-----------+
| ST_Slope(2) | 0.417115 |
+-------------+-----------+
已为 stroke 构建贝叶斯网络
[stroke] 模型CPD验证:
节点 hypertension 的CPD:
+-----------------+-----------+
| hypertension(0) | 0.905076 |
+-----------------+-----------+
| hypertension(1) | 0.0949244 |
+-----------------+-----------+
节点 stroke 的CPD:
+---------------+-----+------------------+
| Age_bin | ... | Age_bin(4) |
+---------------+-----+------------------+
| glucose_bin | ... | glucose_bin(3) |
+---------------+-----+------------------+
| heart_disease | ... | heart_disease(1) |
+---------------+-----+------------------+
| hypertension | ... | hypertension(1) |
+---------------+-----+------------------+
| stroke(0) | ... | 0.5 |
+---------------+-----+------------------+
| stroke(1) | ... | 0.5 |
+---------------+-----+------------------+
节点 glucose_bin 的CPD:
+----------------+-----------+
| glucose_bin(0) | 0.60798 |
+----------------+-----------+
| glucose_bin(1) | 0.227794 |
+----------------+-----------+
| glucose_bin(2) | 0.0777208 |
+----------------+-----------+
| glucose_bin(3) | 0.0865056 |
+----------------+-----------+
已为 cirrhosis 构建贝叶斯网络
[cirrhosis] 模型CPD验证:
节点 Bilirubin_bin 的CPD:
+------------------+----------+
| Bilirubin_bin(0) | 0.446411 |
+------------------+----------+
| Bilirubin_bin(1) | 0.26647 |
+------------------+----------+
| Bilirubin_bin(2) | 0.287119 |
+------------------+----------+
节点 Stage 的CPD:
+---------------+-----+------------------+
| Age_bin | ... | Age_bin(3) |
+---------------+-----+------------------+
| Albumin_bin | ... | Albumin_bin(1) |
+---------------+-----+------------------+
| Ascites | ... | Ascites(Y) |
+---------------+-----+------------------+
| Bilirubin_bin | ... | Bilirubin_bin(2) |
+---------------+-----+------------------+
| Stage(1) | ... | 0.25 |
+---------------+-----+------------------+
| Stage(2) | ... | 0.25 |
+---------------+-----+------------------+
| Stage(3) | ... | 0.25 |
+---------------+-----+------------------+
| Stage(4) | ... | 0.25 |
+---------------+-----+------------------+
节点 Albumin_bin 的CPD:
+----------------+----------+
| Albumin_bin(0) | 0.477876 |
+----------------+----------+
| Albumin_bin(1) | 0.522124 |
+----------------+----------+
已成功保存 heart 网络结构图
已成功保存 stroke 网络结构图
已成功保存 cirrhosis 网络结构图
[heart] 拟合度计算失败: Missing columns in data. Can't find values for the following variables: set()
[heart] 推理失败样本过多(184条),无有效预测结果
[stroke] 拟合度计算失败: Missing columns in data. Can't find values for the following variables: set()
[stroke] 推理失败样本过多(1022条),无有效预测结果
[cirrhosis] 拟合度计算失败: Missing columns in data. Can't find values for the following variables: set()
[cirrhosis] 推理失败样本过多(83条),无有效预测结果
进程已结束,退出代码为 0
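The `拟合度计算失败` lines above report `Missing columns in data. Can't find values for the following variables: set()` — note the reported missing set is *empty*, so the scorer is not actually missing a column; one possible explanation is a state-name or dtype mismatch between the fitted model and the test frame (e.g. integer categories in one, strings in the other). A minimal pre-flight sketch (the node and column names below are illustrative, not taken from this run) at least rules out genuinely absent columns before calling the scorers:

```python
# Hypothetical model nodes and test-frame columns; a set comparison confirms
# whether the test data covers every variable the model expects.
model_nodes = {'Age_bin', 'ST_Slope', 'ChestPainType', 'HeartDisease'}
test_columns = {'Age_bin', 'ST_Slope', 'ChestPainType', 'HeartDisease', 'Sex'}

missing = model_nodes - test_columns   # columns the scorer cannot find
extra = test_columns - model_nodes     # columns the model never uses (harmless)
print(missing)  # -> set()
```

If `missing` is empty but scoring still fails, the next thing to compare is the *values* in each column against the model's learned state names.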
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination
import networkx as nx
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from pgmpy.metrics import log_likelihood_score
from pgmpy.estimators import BicScore
from sklearn.model_selection import train_test_split
from pgmpy.estimators import BayesianEstimator
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
df_heart = pd.read_csv('heart_se.csv')
df_stroke = pd.read_csv('stroke_se.csv')
df_cirrhosis = pd.read_csv('cirrhosis_se.csv')
df_stroke.columns = ['id', 'Sex', 'Age', 'hypertension', 'heart_disease', 'ever_married', 'work_type', 'Residence_type', 'avg_glucose_level', 'bmi', 'smoking_status', 'stroke']
df_stroke['Sex'] = df_stroke['Sex'].apply(lambda x: 'M' if x == 'Male' else 'F')  # note: any non-'Male' value (including 'Other') becomes 'F'
df_cirrhosis['Age'] = (df_cirrhosis['Age'] / 365).astype(int)  # Age is recorded in days; convert to years
df_stroke['Age'] = df_stroke['Age'].astype(int)
def preprocess_cirrhosis(df):
    """Convert the high-cardinality features of the cirrhosis dataset into low-cardinality categorical variables."""
    # 1. Cholesterol: clinical bins in mg/dL
    #    normal: <200; borderline-high: 200-239; high: >=240
    # df['Cholesterol_bin'] = pd.cut(
    #     df['Cholesterol'],
    #     bins=[0, 200, 240, 1000],
    #     labels=['正常', '临界高值', '高胆固醇血症']
    # )
    # 2. Bilirubin: hepatic excretory function (mg/dL)
    #    normal: <1.2; mildly elevated: 1.2-3.0; markedly elevated: >3.0
    df['Bilirubin_bin'] = pd.cut(
        df['Bilirubin'],
        bins=[-np.inf, 1.2, 3.0, np.inf],
        labels=['正常', '轻度升高', '显著升高']
    )
    # 3. Albumin: hepatic synthetic function (g/dL)
    #    normal: 3.5-5.0; decreased: <3.5 (typical of cirrhosis); elevation is rarely clinically significant
    df['Albumin_bin'] = pd.cut(
        df['Albumin'],
        bins=[-np.inf, 3.5, 5.0, np.inf],
        labels=['降低', '正常', '升高']
    )
    # 4. Copper: urinary copper excretion (ug/day)
    #    normal: <50; elevated: >=50 (marker for Wilson's disease)
    df['Copper_bin'] = pd.cut(
        df['Copper'],
        bins=[-np.inf, 50, np.inf],
        labels=['正常', '升高']
    )
    # 5. Alkaline phosphatase (Alk_Phos): hepatobiliary marker (U/L)
    #    normal range: 40-150; elevated: >150
    df['Alk_Phos_bin'] = pd.cut(
        df['Alk_Phos'],
        bins=[-np.inf, 150, np.inf],
        labels=['正常', '升高']
    )
    # 6. SGOT (AST): hepatocellular-injury marker (U/mL)
    #    normal: <40; elevated: >=40
    df['SGOT_bin'] = pd.cut(
        df['SGOT'],
        bins=[-np.inf, 40, np.inf],
        labels=['正常', '升高']
    )
    # 7. Triglycerides (Tryglicerides): blood-lipid marker (mg/dL)
    #    normal: <150; elevated: >=150
    df['Tryglicerides_bin'] = pd.cut(
        df['Tryglicerides'],
        bins=[-np.inf, 150, np.inf],
        labels=['正常', '升高']
    )
    # 8. Platelets: hypersplenism marker in cirrhosis (x10^9/L)
    #    normal: 150-450; decreased: <150
    df['Platelets_bin'] = pd.cut(
        df['Platelets'],
        bins=[-np.inf, 150, np.inf],
        labels=['减少', '正常']
    )
    # 9. Prothrombin time: hepatic synthetic function (seconds)
    #    normal: 11-13; prolonged: >13
    df['Prothrombin_bin'] = pd.cut(
        df['Prothrombin'],
        bins=[-np.inf, 13, np.inf],
        labels=['正常', '延长']
    )
    # Drop the raw high-cardinality columns and keep the discretized ones
    df = df.drop(
        columns=['Copper', 'Alk_Phos', 'SGOT',
                 'Tryglicerides', 'Platelets', 'Prothrombin'],
        errors='ignore'
    )
    return df
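`pd.cut` intervals are right-closed by default, so each boundary value above (e.g. a bilirubin of exactly 1.2 mg/dL) falls into the *lower* bin. A quick sanity check of that boundary behavior, using the Bilirubin bins from `preprocess_cirrhosis`:

```python
import numpy as np
import pandas as pd

# pd.cut with right-closed intervals: 1.2 lands in '正常', 1.21 in '轻度升高',
# 3.0 in '轻度升高', and only values strictly above 3.0 in '显著升高'.
s = pd.Series([0.5, 1.2, 1.21, 3.0, 3.1])
bins = pd.cut(s, bins=[-np.inf, 1.2, 3.0, np.inf],
              labels=['正常', '轻度升高', '显著升高'])
print(list(bins))
```

If left-closed bins were intended instead, `pd.cut(..., right=False)` flips this behavior.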
df_stroke['glucose_bin'] = pd.cut(
    df_stroke['avg_glucose_level'],
    bins=[0, 100, 140, 200, 1000],
    labels=['正常', '偏高', '糖尿病前期', '糖尿病']
)
# Preprocess cirrhosis (discretization + type conversion)
df_cirrhosis = preprocess_cirrhosis(df_cirrhosis)
def preprocess_heart(df):
    """Convert the high-cardinality features of the heart-disease dataset into low-cardinality categorical variables."""
    # 2. Resting blood pressure (RestingBP): binned per hypertension guidelines
    #    normal: <120; elevated: 120-129; stage-1 hypertension: 130-139; stage-2 hypertension: >=140
    df['RestingBP_bin'] = pd.cut(
        df['RestingBP'],
        bins=[0, 120, 130, 140, 300],
        labels=['正常', '升高前期', '1级高血压', '2级高血压']
    )
    # 3. Cholesterol: binned per dyslipidemia criteria
    #    normal: <200; borderline-high: 200-239; high: >=240
    # df['Cholesterol_bin'] = pd.cut(
    #     df['Cholesterol'],
    #     bins=[0, 200, 240, 1000],
    #     labels=['正常', '临界高值', '高胆固醇血症']
    # )
    # 4. Maximum heart rate (MaxHR): binned as a percentage of the age-predicted maximum
    #    predicted maximum heart rate = 220 - age
    df['MaxHR_pct'] = df['MaxHR'] / (220 - df['Age'])
    df['MaxHR_bin'] = pd.cut(
        df['MaxHR_pct'],
        bins=[0, 0.6, 0.85, 1.0, 2.0],
        labels=['低(<60%)', '中(60-85%)', '高(85-100%)', '极高(>100%)']
    )
    df.drop(columns=['MaxHR_pct'], inplace=True)  # drop the temporary helper column
    # 5. ST-segment depression (Oldpeak): degree of myocardial ischemia
    #    none: <=0; mild: 0-1; moderate: 1-2; severe: >2
    df['Oldpeak_bin'] = pd.cut(
        df['Oldpeak'],
        bins=[-0.1, 0, 1, 2, 10.0],
        labels=['无压低', '轻度压低', '中度压低', '严重压低']
    )
    # Drop the raw high-cardinality columns and keep the discretized ones
    df = df.drop(
        columns=['RestingBP', 'MaxHR', 'Oldpeak'],
        errors='ignore'
    )
    return df
df_heart = preprocess_heart(df_heart)
# Cast the target variables to categorical dtype (so pgmpy can handle them)
df_heart['HeartDisease'] = df_heart['HeartDisease'].astype('category')
df_stroke['stroke'] = df_stroke['stroke'].astype('category')
df_cirrhosis['Stage'] = df_cirrhosis['Stage'].astype('category')
def discretize_age_uniform(df, age_col='Age'):
    """
    Bin age uniformly across the three datasets with numeric labels 0-4, covering the full 0.05-82 year range.
    Binning logic: 20-year intervals, balancing clinical age groups against data coverage.
    Labels: 0-4 (smaller numbers mean younger ages).
    """
    df['Age_bin'] = pd.cut(
        df[age_col],
        bins=[-0.1, 19, 39, 59, 79, 120],  # covers all ages from 0.05 to 82
        labels=[0, 1, 2, 3, 4],  # numeric labels (0: child/adolescent; 1: young adult; 2: middle-aged; 3: elderly; 4: very elderly)
        include_lowest=True  # include the minimum value (0.05 years)
    )
    # Inspect the post-binning distribution (watch for empty bins)
    print(f"{df.shape[0]}条数据的年龄分箱分布:")
    print(df['Age_bin'].value_counts().sort_index())  # print in 0-4 order
    return df
df_heart = discretize_age_uniform(df_heart)  # heart disease (ages 28-77)
df_stroke = discretize_age_uniform(df_stroke)  # stroke (ages 0.05-82)
df_cirrhosis = discretize_age_uniform(df_cirrhosis)  # cirrhosis (ages 26-78)
# After binning, drop the raw Age column and keep the numeric bin labels
df_heart.drop(columns=['Age'], inplace=True)
df_stroke.drop(columns=['Age'], inplace=True)
df_cirrhosis.drop(columns=['Age'], inplace=True)
def discretize_cholesterol(df, cholesterol_col='Cholesterol'):
    """
    Cholesterol binning: one shared standard with labels 0-1-2, covering both datasets' cholesterol ranges.
    Binning logic (clinical criteria):
    - 0: normal (<200 mg/dL)
    - 1: borderline-high (200-239 mg/dL)
    - 2: high (>=240 mg/dL)
    """
    if cholesterol_col not in df.columns:
        print("数据集中无胆固醇列,跳过分箱")
        return df
    df['Cholesterol_bin'] = pd.cut(
        df[cholesterol_col],
        bins=[-1, 199, 239, 10000],  # upper bound of 10000 comfortably covers both datasets
        labels=[0, 1, 2],  # numeric labels (0: normal, 1: borderline, 2: high)
        include_lowest=True
    )
    # Verify the bin distribution (make sure no bin is empty)
    print("胆固醇分箱分布(标签0-1-2):")
    print(df['Cholesterol_bin'].value_counts().sort_index())  # print in 0-1-2 order
    # Drop the raw cholesterol column and keep the bins
    df.drop(columns=[cholesterol_col], inplace=True, errors='ignore')
    return df
df_heart = discretize_cholesterol(df_heart, cholesterol_col='Cholesterol')
df_cirrhosis = discretize_cholesterol(df_cirrhosis, cholesterol_col='Cholesterol')
def discretize_stroke_continuous(df):
    # BMI bins (WHO standard)
    df['bmi_bin'] = pd.cut(
        df['bmi'],
        bins=[-1, 18.5, 24, 28, 100],
        labels=['偏瘦', '正常', '超重', '肥胖']
    )
    # Glucose bins (same logic as the earlier top-level binning)
    df['glucose_bin'] = pd.cut(
        df['avg_glucose_level'],
        bins=[0, 100, 140, 200, 1000],
        labels=['正常', '偏高', '糖尿病前期', '糖尿病']
    )
    df.drop(columns=['bmi', 'avg_glucose_level'], inplace=True)
    return df
df_stroke = discretize_stroke_continuous(df_stroke)
# inner_join = pd.merge(df_heart, df_stroke, on=['Age'])
print(df_heart)
print(df_cirrhosis)
print(df_stroke)
# After all preprocessing, force-remove any remaining NaN
def drop_or_fill_nan(df):
    # First try to fill NaN: mode for categorical/object columns, median for numeric ones
    for col in df.columns:
        if df[col].isna().any():
            if df[col].dtype == 'category' or df[col].dtype == 'object':
                df[col] = df[col].fillna(df[col].mode()[0])
            else:
                df[col] = df[col].fillna(df[col].median())
    # If NaN remain, drop those rows outright (guarantees clean data)
    df = df.dropna()
    print(f"处理后的数据形状:{df.shape}(无NaN)")
    return df
df_heart = drop_or_fill_nan(df_heart)
df_stroke = drop_or_fill_nan(df_stroke)
df_cirrhosis = drop_or_fill_nan(df_cirrhosis)
# Heart: map the Chinese labels of Oldpeak_bin, RestingBP_bin, etc. to integers
def replace_heart_labels(df):
    # Oldpeak_bin: 0-3 (none -> mild -> moderate -> severe)
    df['Oldpeak_bin'] = df['Oldpeak_bin'].replace(
        {'无压低': 0, '轻度压低': 1, '中度压低': 2, '严重压低': 3}
    ).astype(int)
    # RestingBP_bin: 0-3 (normal -> elevated -> stage 1 -> stage 2)
    df['RestingBP_bin'] = df['RestingBP_bin'].replace(
        {'正常': 0, '升高前期': 1, '1级高血压': 2, '2级高血压': 3}
    ).astype(int)
    # ST_Slope: 0-2 (Down -> Flat -> Up)
    df['ST_Slope'] = df['ST_Slope'].replace(
        {'Down': 0, 'Flat': 1, 'Up': 2}
    ).astype(int)
    # ChestPainType: 0-3 (ASY -> ATA -> NAP -> TA)
    df['ChestPainType'] = df['ChestPainType'].replace(
        {'ASY': 0, 'ATA': 1, 'NAP': 2, 'TA': 3}
    ).astype(int)
    return df
# Stroke: map glucose_bin, bmi_bin, etc.
def replace_stroke_labels(df):
    df['glucose_bin'] = df['glucose_bin'].replace(
        {'正常': 0, '偏高': 1, '糖尿病前期': 2, '糖尿病': 3}
    ).astype(int)
    df['bmi_bin'] = df['bmi_bin'].replace(
        {'偏瘦': 0, '正常': 1, '超重': 2, '肥胖': 3}
    ).astype(int)
    df['smoking_status'] = df['smoking_status'].replace(
        {'never smoked': 0, 'formerly smoked': 1, 'smokes': 2, 'Unknown': 3}
    ).astype(int)
    return df
# Cirrhosis: map Bilirubin_bin, etc.
def replace_cirrhosis_labels(df):
    df['Bilirubin_bin'] = df['Bilirubin_bin'].replace(
        {'正常': 0, '轻度升高': 1, '显著升高': 2}
    ).astype(int)
    df['Albumin_bin'] = df['Albumin_bin'].replace(
        {'降低': 0, '正常': 1, '升高': 2}
    ).astype(int)
    # remaining label columns can be mapped to 0-1-2 the same way
    return df
# Apply the mappings
df_heart = replace_heart_labels(df_heart)
df_stroke = replace_stroke_labels(df_stroke)
df_cirrhosis = replace_cirrhosis_labels(df_cirrhosis)
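The FutureWarnings in the log above all come from calling `Series.replace` on categorical columns. A future-proof alternative — a sketch, not the script's code — is `cat.rename_categories`, which is exactly what the warning suggests:

```python
import pandas as pd

# Warning-free label mapping for one column; the category names match
# Oldpeak_bin above, the sample values are illustrative.
s = pd.Series(pd.Categorical(
    ['无压低', '中度压低', '无压低'],
    categories=['无压低', '轻度压低', '中度压低', '严重压低'],
    ordered=True
))
mapped = s.cat.rename_categories(
    {'无压低': 0, '轻度压低': 1, '中度压低': 2, '严重压低': 3}
).astype(int)
print(list(mapped))  # -> [0, 2, 0]
```

Unlike `replace`, this keeps the column categorical (and ordered) until the final `astype(int)`, and emits no deprecation warning.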
#----------------------------------------------------------------------------------------------------
def build_bayesian_network(data, file_name, shared_variables=['age', 'gender'],  # parameter name kept consistent across calls
                           local_shared_vars=None,  # local shared-variable parameter
                           included_features=None, custom_edges=None, forbidden_edges=None):
    """Build a Bayesian network for one dataset, supporting multiple shared variables and custom prior knowledge."""
    if data is None or data.empty:
        print(f"无有效数据,无法为 {file_name} 构建模型")
        return None
    # Lowercase the shared-variable names for case-insensitive matching
    shared_variables = [var.lower() for var in shared_variables]
    # If a feature whitelist is given, subset the data
    if included_features:
        # Make sure every shared variable is included
        for var in shared_variables:
            # Find the matching column name (case-insensitive)
            matched_col = next((col for col in data.columns if col.lower() == var), None)
            if matched_col and matched_col not in included_features:
                included_features.append(matched_col)
        data = data[included_features]
    # Define the network structure
    edges = []
    # For each shared variable, add edges only to the target variables
    # (an earlier version connected shared variables to every feature):
    for shared_var in shared_variables:
        matched_col = next((col for col in data.columns if col.lower() == shared_var), None)
        if matched_col:
            # Match target variables case-insensitively
            target_vars = [col for col in data.columns
                           if col.lower() in ['stroke', 'heartdisease']]
            for feature in target_vars:
                if feature != matched_col and (matched_col, feature) not in edges:
                    edges.append((matched_col, feature))  # force a shared-variable -> target edge
    # Add custom edges (domain knowledge)
    if custom_edges:
        edges.extend(custom_edges)
    # Remove forbidden edges (domain knowledge)
    if forbidden_edges:
        edges = [edge for edge in edges if edge not in forbidden_edges]
    # Create the Bayesian network
    model = BayesianNetwork(edges)
    # Learn the parameters with a Bayesian (BDeu) estimator
    model.fit(data, estimator=BayesianEstimator, prior_type="BDeu", equivalent_sample_size=10)
    print(f"已为 {file_name} 构建贝叶斯网络")
    return model
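`BayesianNetwork(edges)` requires the combined edge list (shared-variable edges plus `custom_edges`) to form a DAG, and construction fails otherwise. A stdlib-only sketch, with a hypothetical edge list in the same `(parent, child)` format as `custom_edges`, for sanity-checking before fitting:

```python
from graphlib import TopologicalSorter  # Python 3.9+

# Hypothetical edges; not taken from the configs in this script.
edges = [('Age_bin', 'RestingBP_bin'), ('RestingBP_bin', 'HeartDisease'),
         ('Age_bin', 'HeartDisease')]

# Build child -> {parents} dependencies and attempt a topological order;
# TopologicalSorter raises CycleError if the edges contain a cycle.
deps = {}
for parent, child in edges:
    deps.setdefault(child, set()).add(parent)
order = list(TopologicalSorter(deps).static_order())
print(order)
```

Running this check inside `build_bayesian_network` before `BayesianNetwork(edges)` would turn an opaque constructor error into an explicit `CycleError` naming the offending nodes.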
# Visualize the Bayesian network
def visualize_network(model, file_name):
    """
    Visualize the Bayesian network structure, with extra robustness against layout failures.
    """
    if model is None:
        print(f"警告: {file_name} 模型为空,无法可视化")
        return
    # Work on an undirected copy of the graph (simplifies layout computation)
    G = model.to_undirected() if hasattr(model, 'to_undirected') else model.copy()
    # Strictly filter invalid edges
    valid_edges = []
    for u, v in G.edges():
        if u != v:  # drop self-loops
            valid_edges.append((u, v))
    # Clear the edges, re-add the valid ones, and check whether any remain
    G.remove_edges_from(list(G.edges()))
    G.add_edges_from(valid_edges)
    if not G.edges():
        print(f"错误: {file_name} 网络中没有有效边(可能所有边都是自循环),无法可视化")
        return
    # Try several layout algorithms in order of preference
    try:
        # Prefer a force-directed layout (good for sparse graphs)
        pos = nx.spring_layout(G, k=0.3, iterations=100, seed=42)
    except Exception:
        try:
            # Fallback: Kamada-Kawai layout (good for dense graphs; deterministic, takes no seed)
            pos = nx.kamada_kawai_layout(G)
        except Exception:
            try:
                # Fallback: circular layout
                pos = nx.circular_layout(G)
            except Exception:
                # Last resort: random layout
                pos = nx.random_layout(G, seed=42)
                print(f"警告: {file_name} 网络使用随机布局,可能效果不佳")
    # Check for overlapping node positions (distance below a threshold)
    MIN_DISTANCE = 0.01
    positions = list(pos.values())
    has_overlap = False
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            dist = np.sqrt(((positions[i][0] - positions[j][0]) ** 2) +
                           ((positions[i][1] - positions[j][1]) ** 2))
            if dist < MIN_DISTANCE:
                has_overlap = True
                break
        if has_overlap:
            break
    # If nodes overlap, jitter them slightly
    if has_overlap:
        print(f"警告: {file_name} 网络节点位置存在重叠,添加随机扰动")
        for node in pos:
            pos[node] = (
                pos[node][0] + np.random.uniform(-0.02, 0.02),
                pos[node][1] + np.random.uniform(-0.02, 0.02)
            )
    # Create the figure
    plt.figure(figsize=(14, 12))  # larger canvas
    # Draw the nodes
    nx.draw_networkx_nodes(
        G, pos,
        node_size=2800,  # larger nodes
        node_color='skyblue',
        alpha=0.8,
        edgecolors='black',  # node borders
        linewidths=1.0
    )
    # Draw the edges one at a time, catching per-edge failures
    edges_to_draw = list(G.edges())
    successful_edges = []
    for edge in edges_to_draw:
        try:
            nx.draw_networkx_edges(
                G, pos,
                edgelist=[edge],
                arrows=True,
                arrowsize=20,
                width=1.5,
                alpha=0.7,
                edge_color='gray'
            )
            successful_edges.append(edge)
        except Exception as e:
            print(f"警告: 无法绘制边 {edge}: {e}")
    if len(successful_edges) == 0:
        print(f"错误: {file_name} 网络中没有成功绘制的边")
        plt.close()
        return
    # Draw the node labels
    nx.draw_networkx_labels(
        G, pos,
        font_size=10,
        font_weight='bold',
        font_family='SimHei',  # ensure Chinese labels render
        horizontalalignment='center',
        verticalalignment='center'
    )
    # Title, save, clean up
    plt.title(f'{file_name} 的贝叶斯网络结构', fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    try:
        plt.savefig(f'{file_name}_bayesian_network.png', dpi=300, bbox_inches='tight')
        print(f"已成功保存 {file_name} 网络结构图")
    except Exception as e:
        print(f"保存图像时出错: {e}")
    plt.close()  # free the figure's memory
# Cross-network inference
def cross_network_inference(models, file_names, global_shared_vars=['age', 'gender'],
                            local_shared_vars=None, evidence=None):
    """Run cross-network inference via shared variables, supporting both global and local shared variables."""
    results = {}
    if evidence is None:
        evidence = {'age': 50, 'gender': 'Male'}  # default global evidence
    for model, file_name in zip(models, file_names):
        if model is None:
            continue
        # Map global evidence names onto the model's actual column names
        model_evidence = {}
        for var, value in evidence.items():
            if var.lower() in [v.lower() for v in global_shared_vars]:
                matched_col = next((col for col in model.nodes() if col.lower() == var.lower()), None)
                if matched_col:
                    model_evidence[matched_col] = value
        # Add evidence for any local shared variables
        if local_shared_vars:
            for local_var_name, local_var_col in local_shared_vars.items():
                if local_var_name in evidence and local_var_col in model.nodes():
                    model_evidence[local_var_col] = evidence[local_var_name]
        infer = VariableElimination(model)
        # Query every non-evidence variable in the model
        key_variables = [var for var in model.nodes() if var not in model_evidence]
        for var in key_variables:
            try:
                result = infer.query(variables=[var], evidence=model_evidence)
                results[f"{file_name}_{var}"] = result
                print(f"\n{file_name} 中,已知 {model_evidence} 时,{var} 的概率分布:")
                print(result)
            except Exception as e:
                print(f"\n{file_name} 推理 {var} 时错误: {e}")
    return results
def merge_rare_classes(df, target_variable, min_samples=2):
    """
    Merge target classes whose sample count falls below a threshold.
    :param df: dataset
    :param target_variable: target column name
    :param min_samples: minimum sample-count threshold
    :return: processed dataset
    """
    if target_variable not in df.columns:
        return df
    # Count samples per class
    class_counts = df[target_variable].value_counts()
    rare_classes = class_counts[class_counts < min_samples].index.tolist()
    if not rare_classes:
        return df  # no rare classes, nothing to do
    print(f"检测到稀有类别: {rare_classes},样本数阈值: {min_samples}")
    # Build a mapping: rare class -> class it gets merged into
    # Strategy: merge into the nearest class (by numeric or lexicographic order of the labels)
    sorted_classes = sorted(df[target_variable].unique())
    class_mapping = {}
    for rare_class in rare_classes:
        # Find the nearest valid class
        idx = sorted_classes.index(rare_class)
        # Look one step backwards, then forwards, for a non-rare class
        for direction in [-1, 1]:
            new_idx = idx + direction
            if 0 <= new_idx < len(sorted_classes) and sorted_classes[new_idx] not in rare_classes:
                class_mapping[rare_class] = sorted_classes[new_idx]
                break
    # Apply the mapping
    if class_mapping:
        print(f"类别合并映射: {class_mapping}")
        df[target_variable] = df[target_variable].replace(class_mapping)
    return df
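The nearest-class strategy in `merge_rare_classes` can be exercised on a toy series. This sketch (invented data, not from the three datasets) mirrors the function's steps: count classes, flag the rare ones, and fold each into its nearest non-rare neighbour:

```python
import pandas as pd

# Class 1 has only one sample (below min_samples=2), so it is merged into
# its nearest non-rare neighbour, class 2 — the same mapping the function
# would build for a Stage-like target.
s = pd.Series([1, 2, 2, 3, 3, 3, 4, 4])
counts = s.value_counts()
rare = counts[counts < 2].index.tolist()  # -> [1]
sorted_classes = sorted(s.unique())       # -> [1, 2, 3, 4]
mapping = {1: 2}                          # nearest non-rare neighbour of 1
merged = s.replace(mapping)
print(sorted(merged.unique()))  # -> [2, 3, 4]
```

This matters for the cirrhosis `Stage` target in particular, where very small classes would otherwise yield CPD columns estimated from a handful of rows.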
def evaluate_model(model, test_data, target_variable, file_name):
    """
    Evaluate a Bayesian network: predictive performance + goodness-of-fit metrics.
    :param model: fitted Bayesian network (pgmpy BayesianNetwork)
    :param test_data: preprocessed test set (pd.DataFrame)
    :param target_variable: target variable to predict (e.g. "HeartDisease", "stroke")
    :param file_name: dataset name (for logging)
    :return: dict of evaluation results
    """
    results = {}
    if model is None or test_data.empty:
        print(f"[{file_name}] 模型或测试数据为空,跳过评估")
        return results
    # ----------------------
    # 1. Goodness-of-fit metrics (BIC, log-likelihood)
    # ----------------------
    try:
        bic_score = BicScore(test_data).score(model)  # lower is better (balances fit and complexity)
        ll_score = log_likelihood_score(model, test_data)  # higher is better (data fit)
        results.update({
            "bic_score": bic_score,
            "log_likelihood": ll_score
        })
    except Exception as e:
        print(f"[{file_name}] 拟合度计算失败: {e}")
        bic_score = ll_score = None
    # ----------------------
    # 2. Predictive performance (classification)
    # ----------------------
    if target_variable not in model.nodes() or target_variable not in test_data.columns:
        print(f"[{file_name}] 目标变量 {target_variable} 不在模型或测试数据中")
        return results
    # Split features and labels
    X_test = test_data.drop(columns=[target_variable], errors='ignore')
    y_test = test_data[target_variable].dropna()  # drop rows with a missing label
    X_test = X_test.loc[y_test.index]  # keep indices aligned
    if len(y_test) < 10:  # too few samples for a meaningful evaluation
        print(f"[{file_name}] 有效样本不足(仅{len(y_test)}条),跳过预测评估")
        return results
    infer = VariableElimination(model)
    y_pred = []
    fail_count = 0
    for idx, row in X_test.iterrows():
        # Build the evidence dict (skip missing values; the network marginalizes over them)
        evidence = {}
        for var in model.nodes():
            if var == target_variable:
                continue
            val = row.get(var)
            if not pd.isna(val):
                evidence[var] = val
        # Infer the posterior over the target variable
        try:
            query = infer.query(variables=[target_variable], evidence=evidence, show_progress=False)
            # Pick the most probable class via the factor's values and state names
            pred_idx = int(np.argmax(query.values))
            pred_class = query.state_names[target_variable][pred_idx]
            y_pred.append(pred_class)
        except Exception:
            fail_count += 1
            y_pred.append(None)  # mark samples where inference failed
    # Filter out failed predictions
    valid_mask = [p is not None for p in y_pred]
    y_test_valid = y_test[valid_mask]
    y_pred_valid = [p for p in y_pred if p is not None]
    if len(y_test_valid) == 0:
        print(f"[{file_name}] 推理失败样本过多({fail_count}条),无有效预测结果")
        return results
    # Classification metrics
    try:
        accuracy = accuracy_score(y_test_valid, y_pred_valid)
        cm = confusion_matrix(y_test_valid, y_pred_valid)
        report = classification_report(y_test_valid, y_pred_valid, zero_division=0)
        results.update({
            "accuracy": accuracy,
            "confusion_matrix": cm,
            "classification_report": report
        })
    except Exception as e:
        print(f"[{file_name}] 分类指标计算失败: {e}")
        accuracy = cm = report = None
    # ----------------------
    # 3. Print the evaluation report
    # ----------------------
    print("\n" + "=" * 40)
    print(f"[{file_name}] 模型评估报告")
    print("=" * 40)
    if bic_score is not None:
        print(f"• BIC评分: {bic_score:.2f}(越小越好,平衡复杂度与拟合度)")
        print(f"• 对数似然: {ll_score:.2f}(越大越好,数据拟合度)")
    if accuracy is not None:
        print(f"• 预测准确率: {accuracy:.4f}")
        print(f"• 混淆矩阵:\n{cm}")
        print(f"• 分类报告:\n{report}")
        print(f"• 有效样本数: {len(y_test_valid)} / {len(test_data)}")
        print(f"• 推理失败数: {fail_count}")
    return results
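Extracting the most probable class from a single-variable posterior reduces to an argmax over the factor's value array, aligned with its state names — a more portable approach than relying on a factor-level `argmax()` helper, whose availability varies across pgmpy versions. A plain-numpy sketch with hypothetical states standing in for the pgmpy factor:

```python
import numpy as np

# Hypothetical posterior over a binary target, as a pgmpy single-variable
# query would expose via `.values` (probabilities) and `.state_names`
# (the matching state labels).
states = [0, 1]                       # state labels for the target variable
posterior = np.array([0.337, 0.663])  # P(target = state | evidence)
map_state = states[int(np.argmax(posterior))]
print(map_state)  # -> 1
```

The same index-then-lookup step works for any number of states, and keeps the predicted label in the target column's own dtype.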
def print_model_cpds(model, file_name, top_n=3):
    """Print the CPDs of the model's first N nodes to verify they are valid."""
    if model is None:
        return
    print(f"\n[{file_name}] 模型CPD验证:")
    for i, node in enumerate(model.nodes()):
        if i >= top_n:
            break
        cpd = model.get_cpds(node)
        if cpd:
            print(f"\n节点 {node} 的CPD:")
            print(cpd)
        else:
            print(f"节点 {node} 无有效CPD!")
# Main function - supports custom configuration
def main():
    file_names = ['heart', 'stroke', 'cirrhosis']
    print('Main already')
    df_heart['HeartDisease'] = df_heart['HeartDisease'].astype('category')
    df_stroke['stroke'] = df_stroke['stroke'].astype('category')
    df_cirrhosis['Stage'] = df_cirrhosis['Stage'].astype('category')
    target_mapping = {
        'heart': 'HeartDisease',
        'stroke': 'stroke',
        'cirrhosis': 'Stage'
    }
    # Load the data
    datasets = [df_heart, df_stroke, df_cirrhosis]
    processed_datasets = []
    for data, name in zip(datasets, file_names):
        target_var = target_mapping.get(name)
        if not target_var or target_var not in data.columns:
            processed_datasets.append(data)
            continue
        print(f"\n处理 {name} 数据集中的稀有类别(目标变量: {target_var})")
        # Merge rare classes (preferred over dropping them)
        processed_data = merge_rare_classes(data, target_var, min_samples=5)
        # Make sure the target variable is categorical
        processed_data[target_var] = processed_data[target_var].astype('category')
        processed_datasets.append(processed_data)
    train_datasets = []
    test_datasets = []
    for data in processed_datasets:
        train, test = train_test_split(data, test_size=0.2, random_state=42)
        train_datasets.append(train)
        test_datasets.append(test)
    # Define a custom configuration for each dataset
    # (earlier, fuller configuration kept for reference)
    # disease_configs = [
    #     {
    #         "disease_name": "心脏病",
    #         "local_shared_vars": {"cholesterol": "Cholesterol_bin"},
    #         "included_features": [
    #             "Age_bin", "Sex", "ChestPainType", "RestingBP_bin", "Cholesterol_bin",
    #             "FastingBS", "ExerciseAngina", "Oldpeak_bin", "ST_Slope", "HeartDisease"
    #         ],
    #         "custom_edges": [
    #             # Core direct risk factors
    #             ("ChestPainType", "HeartDisease"),
    #             ("ST_Slope", "HeartDisease"),
    #             ("Oldpeak_bin", "HeartDisease"),
    #             ("Cholesterol_bin", "HeartDisease"),
    #             ("RestingBP_bin", "HeartDisease"),
    #             ("ExerciseAngina", "HeartDisease"),
    #             ("Age_bin", "HeartDisease"),
    #
    #             # Key indirect paths
    #             ("FastingBS", "Cholesterol_bin"),  # high blood glucose affects cholesterol
    #             ("Age_bin", "RestingBP_bin"),      # age affects blood pressure
    #         ],
    #         "forbidden_edges": [
    #             ("Sex", "HeartDisease"),        # sex acts indirectly through other factors
    #             ("HeartDisease", "*"),          # avoid reversed causality
    #             ("FastingBS", "HeartDisease"),  # glucose acts mainly through cholesterol
    #         ]
    #     },
    #     {
    #         "disease_name": "中风",
    #         "local_shared_vars": {},
    #         "included_features": [
    #             "Age_bin", "Sex", "hypertension", "heart_disease",
    #             "glucose_bin", "bmi_bin", "smoking_status", "stroke"
    #         ],
    #         "custom_edges": [
    #             # Core direct risk factors
    #             ("hypertension", "stroke"),
    #             ("heart_disease", "stroke"),
    #             ("Age_bin", "stroke"),
    #             ("glucose_bin", "stroke"),
    #
    #             # Key indirect paths
    #             ("smoking_status", "heart_disease"),  # smoking leads to heart disease
    #             ("bmi_bin", "hypertension"),          # obesity leads to hypertension
    #             ("bmi_bin", "glucose_bin"),           # obesity affects blood glucose
    #         ],
    #         "forbidden_edges": [
    #             ("Sex", "stroke"),      # sex acts indirectly through other factors
    #             ("stroke", "*"),        # avoid reversed causality
    #             ("bmi_bin", "stroke"),  # BMI acts mainly through hypertension and glucose
    #         ]
    #     },
    #     {
    #         "disease_name": "肝硬化",
    #         "local_shared_vars": {"cholesterol": "Cholesterol_bin"},
    #         "included_features": [
    #             "Age_bin", "Sex", "Ascites", "Hepatomegaly", "Spiders",
    #             "Edema", "Bilirubin_bin", "Albumin_bin",
    #             "SGOT_bin", "Platelets_bin", "Prothrombin_bin", "Stage"
    #         ],
    #         "custom_edges": [
    #             # Liver-function markers -> disease stage
    #             ("Bilirubin_bin", "Stage"),    # rising bilirubin -> cirrhosis progression
    #             ("Albumin_bin", "Stage"),      # falling albumin -> cirrhosis progression
    #             ("Prothrombin_bin", "Stage"),  # prolonged prothrombin time -> cirrhosis
    #             ("Platelets_bin", "Stage"),    # low platelets -> cirrhosis
    #
    #             # Physical signs -> disease stage
    #             ("Ascites", "Stage"),       # ascites -> late-stage cirrhosis
    #             ("Hepatomegaly", "Stage"),  # hepatomegaly -> cirrhosis
    #
    #             # Indirect links
    #             ("Age_bin", "Stage"),        # age -> disease progression
    #             ("Ascites", "Edema"),        # ascites -> edema
    #             ("Albumin_bin", "Ascites"),  # low albumin -> ascites
    #         ],
    #         "forbidden_edges": [
    #             ("ID", "*"),
    #             ("Status", "*"),
    #             ("Drug", "Stage"),
    #             ("Sex", "Stage"),  # sex does not directly affect cirrhosis stage
    #             ("Stage", "*"),    # avoid reversed causality
    #         ]
    #     }
    # ]
    disease_configs = [
        {
            "disease_name": "心脏病",
            "local_shared_vars": {"cholesterol": "Cholesterol_bin"},
            "included_features": [
                "Age_bin", "ChestPainType", "ST_Slope", "RestingBP_bin",  # 4 core parent nodes
                "HeartDisease"  # target variable
            ],
            "custom_edges": [
                # Keep only the 4 core parents that directly influence heart disease
                ("ChestPainType", "HeartDisease"),  # chest pain type (most telling symptom)
                ("ST_Slope", "HeartDisease"),       # ST-segment slope (diagnostic gold standard)
                ("RestingBP_bin", "HeartDisease"),  # blood pressure (direct risk)
                ("Age_bin", "HeartDisease")         # age (baseline risk)
            ],
            "forbidden_edges": []  # no forbidden edges, to avoid pruning core links
        },
        {
            "disease_name": "中风",
            "local_shared_vars": {},
            "included_features": [
                "hypertension", "glucose_bin", "Age_bin", "heart_disease",  # 4 core parent nodes
                "stroke"  # target variable
            ],
            "custom_edges": [
                # Keep only the 4 core parents that directly influence stroke
                ("hypertension", "stroke"),   # hypertension (leading risk factor)
                ("glucose_bin", "stroke"),    # high blood glucose (independent risk)
                ("heart_disease", "stroke"),  # heart disease history (thrombosis risk)
                ("Age_bin", "stroke")         # age (cumulative risk)
            ],
            "forbidden_edges": []
        },
        {
            "disease_name": "肝硬化",
            "local_shared_vars": {"cholesterol": "Cholesterol_bin"},
            "included_features": [
                "Bilirubin_bin", "Albumin_bin", "Age_bin", "Ascites",  # 4 core parent nodes
                "Stage"  # target variable
            ],
            "custom_edges": [
                # Keep only the 4 core parents that directly influence cirrhosis stage
                ("Bilirubin_bin", "Stage"),  # bilirubin (hepatic excretory function)
                ("Albumin_bin", "Stage"),    # albumin (hepatic synthetic function)
                ("Ascites", "Stage"),        # ascites (marker of late-stage cirrhosis)
                ("Age_bin", "Stage")         # age (accumulated disease course)
            ],
            "forbidden_edges": []
        }
    ]
    print('Config already')
    # 3. Build the models (pass the global and local shared variables)
    models = []
    for train_data, name, config in zip(train_datasets, file_names, disease_configs):
        model = build_bayesian_network(
            data=train_data,
            file_name=name,
            shared_variables=["age", "gender"],
            local_shared_vars=config["local_shared_vars"],
            included_features=config["included_features"],
            custom_edges=config["custom_edges"],
            forbidden_edges=config["forbidden_edges"]
        )
        print_model_cpds(model, name)
        models.append(model)
    # 4. Visualize the networks
    for model, name in zip(models, file_names):
        visualize_network(model, name)
    for model, test_data, name, config in zip(models, test_datasets, file_names, disease_configs):
        # Pick the target variable for this disease, reusing the
        # target_mapping dict defined above (no need to redefine it here)
        target_var = target_mapping.get(name)
        if not target_var:
            print(f"[{name}] 未知目标变量,跳过验证")
            continue
        # Run the evaluation
        evaluate_model(model, test_data, target_var, name)
    # # 5. Cross-network inference (pass the shared-variable mapping so evidence matches)
    # # Note: variable names in the evidence must use the "logical names" from the
    # # config, not the original column names
    # evidence = {
    #     "age": 60,           # global shared variable (age)
    #     "gender": "男",      # global shared variable (sex)
    #     "cholesterol": 240   # local shared variable (cholesterol, logical name)
    # }
    #
    # cross_network_inference(
    #     models=models,
    #     file_names=file_names,
    #     global_shared_vars=["age", "gender"],              # global shared variables
    #     local_shared_vars={"cholesterol": "Cholesterol"},  # logical name -> actual column
    #     evidence=evidence
    # )


if __name__ == "__main__":
    main()
What's going on here? Please take a look and help me figure out how to fix it.