ES: range + IN (terms) query


GET /idx_hippo_sku_statistics_index/sku_statistics_index/_search
{
  "from": 0,
  "size": 100,
  "query": {
    "bool": {
      "must": [
        {
          "range": {
            "salePrice": {
              "from": 1.9,
              "to": 2.1,
              "include_lower": true,
              "include_upper": true
            }
          }
        },
        {
          "terms": {
            "cityId": [136]
          }
        }
      ]
    }
  },
  "sort": [
    { "salePrice": "asc" }
  ],
  "_source": ["productName", "salePrice"]
}
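The from/to plus include_lower/include_upper flags are the legacy range syntax; newer Elasticsearch releases spell the same bounds with gte/gt (lower, inclusive/exclusive) and lte/lt (upper). A minimal sketch of the two equivalent clauses as Python dicts (field name and bounds taken from the query above):

```python
# Legacy range syntax, as used in the query above: explicit bound
# values plus include_lower / include_upper inclusivity flags.
legacy_range = {
    "range": {
        "salePrice": {
            "from": 1.9,
            "to": 2.1,
            "include_lower": True,   # True -> >=, False -> >
            "include_upper": True,   # True -> <=, False -> <
        }
    }
}

# Equivalent modern syntax: gte/lte are inclusive, gt/lt exclusive.
modern_range = {
    "range": {
        "salePrice": {
            "gte": 1.9,
            "lte": 2.1,
        }
    }
}
```

Both clauses should match the same documents; the gte/lte form is just the more compact spelling.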

"range" : { "salePrice" : { "from" : 1.9, "to" : 2.1, "include_lower": true, "include_upper": true } }

表示范围查询,查询      1.9 <= salePrice <= 2.1 的

include_lower  为  ture   >=     false >

include_upper  为  ture  <=     false <

terms 表示单个查询;
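The combined bool/must filter above behaves like SQL's `WHERE salePrice BETWEEN 1.9 AND 2.1 AND cityId IN (...)`. A pure-Python sketch of the same matching logic over made-up sample documents (the data and helper are hypothetical, for illustration only):

```python
# Made-up sample documents, mimicking the indexed sku_statistics records.
docs = [
    {"productName": "A", "salePrice": 1.9, "cityId": 136},  # boundary: kept (include_lower)
    {"productName": "B", "salePrice": 2.0, "cityId": 136},
    {"productName": "C", "salePrice": 2.1, "cityId": 999},  # wrong city: dropped by terms
    {"productName": "D", "salePrice": 2.5, "cityId": 136},  # above range: dropped
]

def matches(doc, lo=1.9, hi=2.1, city_ids=(136,),
            include_lower=True, include_upper=True):
    """Mirror the bool/must semantics: every clause must hold."""
    price = doc["salePrice"]
    lo_ok = price >= lo if include_lower else price > lo
    hi_ok = price <= hi if include_upper else price < hi
    in_ok = doc["cityId"] in city_ids   # terms behaves like SQL IN
    return lo_ok and hi_ok and in_ok

# Apply the filter, then sort ascending by salePrice like the query's "sort".
hits = sorted((d for d in docs if matches(d)), key=lambda d: d["salePrice"])
print([d["productName"] for d in hits])   # ['A', 'B']
```

Flipping include_lower to False would turn >= into > and drop the boundary document A, which is exactly what the flag controls in the real query.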

可视化网络 for model, name in zip(models, file_names): visualize_network(model, name) for model, test_data, name, config in zip(models, test_datasets, file_names, disease_configs): # 根据疾病类型确定目标变量 target_mapping = { 'heart': 'HeartDisease', 'stroke': 'stroke', 'cirrhosis': 'Stage' } target_var = target_mapping.get(name) if not target_var: print(f"[{name}] 未知目标变量,跳过验证") continue # 执行评估 evaluate_model(model, test_data, target_var, name) # # 5. 跨网络推理(传递共享变量映射,确保证据匹配) # # 注意:证据中的变量名需与配置中的“逻辑名”一致,而非原始列名 # evidence = { # "age": 60, # 全局共享变量(年龄) # "gender": "男", # 全局共享变量(性别) # "cholesterol": 240 # 局部共享变量(胆固醇,逻辑名) # } # # cross_network_inference( # models=models, # file_names=file_names, # global_shared_vars=["age", "gender"], # 全局共享变量 # local_shared_vars={"cholesterol": "Cholesterol"}, # 局部共享变量映射(逻辑名→实际列名) # evidence=evidence # ) if __name__ == "__main__": main() 什么情况,你帮我看看,怎么解决
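For context: the script above calls `merge_rare_classes(data, target_var, min_samples=5)`, but that helper's definition is not included in the question. A minimal sketch of what such a helper might look like is below — the name and `min_samples` parameter come from the call site; the merging strategy (folding target classes with too few rows into a single `"other"` label) is an assumption, not the asker's actual implementation:

```python
import pandas as pd

def merge_rare_classes(data: pd.DataFrame, target_var: str,
                       min_samples: int = 5) -> pd.DataFrame:
    """Fold target classes with fewer than `min_samples` rows into 'other'.

    Sketch only: the real helper in the question is not shown, so the
    merging strategy here is an assumption.
    """
    data = data.copy()
    counts = data[target_var].value_counts()
    rare = counts[counts < min_samples].index
    if len(rare) > 0:
        # Cast to object first so a label that is not an existing
        # category level can be written without errors.
        data[target_var] = data[target_var].astype(object)
        data.loc[data[target_var].isin(rare), target_var] = "other"
    return data
```

Merging rare target classes like this matters for discrete Bayesian networks: a class seen only a handful of times yields CPT columns estimated from almost no data, which is one source of the zero-probability and inference failures discussed in the answer below.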
Answered 07-15:
Errors during the inference and evaluation stages of a Bayesian network model can have several causes, including data-quality problems, an unreasonable network structure, inaccurate parameter estimation, and bugs in the algorithm implementation. Each is analyzed below with its corresponding fix.

### Data quality problems

If the training data contain missing values, outliers, or noise, the model may fail to fit the data correctly, which corrupts inference results. For example, some samples may be unusable for probability computation because of missing fields, or noise may distort the learned conditional probability distributions [^1].

- **Fix**: Preprocess the data: impute missing values (e.g. interpolation or mean imputation), detect and correct outliers, and denoise. Good data quality is the foundation of a reliable Bayesian network.

### Unreasonable model structure

The structure of a Bayesian network encodes the dependencies between variables. A poorly designed structure — whether over-simplified or over-complicated — cannot capture the causal relationships between variables accurately, which degrades inference [^1].

- **Fix**: Use a sound structure-learning algorithm (e.g. K2 or the PC algorithm) to optimize the network structure. Adjusting the structure manually with domain knowledge is also effective. Evaluate candidate structures with cross-validation and select the best one.

### Inaccurate parameter estimation

Even with a reasonable structure, inaccurate parameters (large deviations in the conditional probability tables) will make inference fail. This typically happens when the data are scarce or the class distribution is imbalanced [^1].

- **Fix**: Collect more training data to improve estimation accuracy; for small samples, incorporate prior knowledge (e.g. Bayesian estimation) or apply regularization to reduce overfitting. Also check the conditional probability tables for logical consistency and eliminate implausible probability values.

### Inference-algorithm problems

The inference algorithm (variable elimination, belief propagation, etc.) may be implemented incorrectly or be unsuitable for a particular network structure. For example, running plain belief propagation on a network containing loops may fail to converge [^1].

- **Fix**: Verify the algorithm implementation, consulting authoritative references or open-source libraries (e.g. `pgmpy`, `BayesNet Toolbox`) where necessary. For networks with complex structure, consider approximate inference (e.g. Markov chain Monte Carlo sampling) instead of exact inference.

### Inappropriate evaluation metrics

If the metrics used during evaluation do not match the task (e.g. applying regression metrics to a classification task), the evaluation may fail or give unreliable results [^1].

- **Fix**: Choose metrics appropriate to the task. For classification, common choices are accuracy, recall, and the F1 score; for probabilistic prediction, use log-likelihood or the Brier score. Combine these with a confusion-matrix analysis to pinpoint where the model goes wrong.

### Code example

Below is a simple Bayesian-network inference example using Python's `pgmpy` library:

```python
from pgmpy.models import BayesianModel
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import VariableElimination

# Define the network structure
model = BayesianModel([('A', 'C'), ('B', 'C')])

# Define the conditional probability distributions
cpd_a = TabularCPD(variable='A', variable_card=2, values=[[0.6], [0.4]])
cpd_b = TabularCPD(variable='B', variable_card=2, values=[[0.7], [0.3]])
cpd_c = TabularCPD(
    variable='C',
    variable_card=2,
    values=[
        [0.9, 0.6, 0.7, 0.1],  # P(C=0 | A, B)
        [0.1, 0.4, 0.3, 0.9]   # P(C=1 | A, B)
    ],
    evidence=['A', 'B'],
    evidence_card=[2, 2]
)

# Attach the CPDs to the model
model.add_cpds(cpd_a, cpd_b, cpd_c)

# Initialize the inference engine
infer = VariableElimination(model)

# Run a query
query_result = infer.query(variables=['C'], evidence={'A': 1, 'B': 0})
print(query_result)
```
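To make the small-sample advice above concrete, the sketch below estimates a conditional probability table with additive (Laplace) smoothing in plain pandas: a pseudo-count `alpha` plays the role of the prior, so no CPT cell is ever exactly zero, even for parent/child combinations absent from the training data. This is a hand-rolled illustration with made-up data, not the questioner's pipeline (in `pgmpy` the analogous built-in tool is the Bayesian parameter estimator):

```python
import pandas as pd

def smoothed_cpt(data: pd.DataFrame, child: str, parent: str,
                 alpha: float = 1.0) -> pd.DataFrame:
    """P(child | parent) with additive smoothing:
    (count + alpha) / (parent_total + alpha * n_child_states)."""
    child_states = sorted(data[child].unique())
    counts = (data.groupby([parent, child]).size()
                  .unstack(fill_value=0)
                  .reindex(columns=child_states, fill_value=0))
    smoothed = counts + alpha          # add the pseudo-count to every cell
    return smoothed.div(smoothed.sum(axis=1), axis=0)  # normalize each row

# Tiny made-up sample: with alpha=1 no probability collapses to 0,
# even though (hypertension=1, stroke=0) never occurs below.
df = pd.DataFrame({
    "hypertension": [0, 0, 0, 1, 1],
    "stroke":       [0, 0, 1, 1, 1],
})
cpt = smoothed_cpt(df, child="stroke", parent="hypertension")
print(cpt)
```

With `alpha=0` the unseen combination would get probability 0 and any test sample hitting it would make inference fail — exactly the "inference failure count" symptom printed by the evaluation code in the question.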
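The probabilistic metrics mentioned above — log-likelihood and the Brier score — can both be computed directly from predicted probabilities. A minimal sketch using the standard formulas and made-up numbers:

```python
import math

def log_likelihood(y_true, p_pred):
    """Sum of log P(observed class); larger (closer to 0) is better."""
    return sum(math.log(p if y == 1 else 1.0 - p)
               for y, p in zip(y_true, p_pred))

def brier_score(y_true, p_pred):
    """Mean squared error between predicted probability and outcome;
    smaller is better."""
    return sum((p - y) ** 2 for y, p in zip(y_true, p_pred)) / len(y_true)

y_true = [1, 0, 1, 1]          # observed binary labels (made up)
p_pred = [0.9, 0.2, 0.8, 0.6]  # predicted P(y=1) (made up)

print(round(log_likelihood(y_true, p_pred), 4))  # sums ln(0.9)+ln(0.8)+ln(0.8)+ln(0.6)
print(round(brier_score(y_true, p_pred), 4))
```

Unlike plain accuracy, both metrics reward well-calibrated probabilities, which is what a Bayesian network actually outputs.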