import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import pydicom
import numpy as np
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ReduceLROnPlateau
import random
import matplotlib.pyplot as plt
from joblib import Memory
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from sklearn.calibration import calibration_curve
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
import torchxrayvision as xrv # 导入torchxrayvision库
from sklearn.metrics import (
precision_score, recall_score, f1_score, confusion_matrix,
roc_auc_score, roc_curve, auc, precision_recall_curve,
brier_score_loss # 确保包含这个
)
from imblearn.over_sampling import SMOTE
# Restrict this process to one GPU (must happen before CUDA is first used).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # change to the GPU index you want to use

# Hardware configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True  # let cuDNN autotune kernels for the fixed input size

# Data configuration. IMAGE_SIZE is the single source of truth for the image side length.
# (A `global` statement at module level has no effect, so none is needed here.)
IMAGE_SIZE = 224  # every image is resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32
NUM_WORKERS = 4

# On-disk cache used by the @memory.cache decorated preprocessing functions.
memory = Memory(location='./cache', verbose=0)
# torchxrayvision models are trained on single-channel images scaled to [-1024, 1024]
# (xrv.datasets.normalize computes (2 * x / maxval - 1) * 1024). ToTensor() yields [0, 1],
# so Normalize(mean=0.5, std=0.5/1024) maps x -> (2x - 1) * 1024, which is exactly the input
# range the pretrained DenseNet and its BatchNorm running statistics were fitted on.
XRV_MEAN = [0.5]
XRV_STD = [0.5 / 1024]

train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE + 32, IMAGE_SIZE + 32)),  # enlarge first so RandomCrop has room
    transforms.RandomCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
    # Saturation and hue have no effect on grayscale images, so only contrast is jittered.
    transforms.ColorJitter(contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=XRV_MEAN, std=XRV_STD),
])

val_test_transform = transforms.Compose([
    # Explicit (H, W) so that even a non-square input comes out IMAGE_SIZE x IMAGE_SIZE.
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=XRV_MEAN, std=XRV_STD),
])
# -------------------- Data preprocessing --------------------
def get_patient_name_by_accession(accession, df1):
    """Look up a patient's name by radiology accession number.

    Args:
        accession: Accession number taken from the DICOM header (compared as text).
        df1: DataFrame whose 2nd column holds accession numbers and 3rd column patient names.

    Returns:
        The stripped patient name from the first matching row, or None when there is no
        match or the matching name cell is empty.
    """
    def _as_text(value):
        # Excel usually stores accession numbers as int, or as float (e.g. 1001.0) when the
        # column has blanks; the DICOM value is a string. Normalise both sides to stripped
        # text so the comparison can actually match.
        if pd.isna(value):
            return None
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        return str(value).strip()

    target = _as_text(accession)
    if target is None:
        return None
    matching_rows = df1[df1.iloc[:, 1].map(_as_text) == target]
    if matching_rows.empty:
        return None
    patient_name = matching_rows.iloc[0, 2]
    if pd.isna(patient_name):
        return None
    return str(patient_name).strip()
def get_emphysema_label_by_name(patient_name, df2):
    """Look up the emphysema label of a patient by name.

    Args:
        patient_name: Patient name; compared after stripping surrounding whitespace.
        df2: DataFrame whose 2nd column holds names and whose label column's header
            contains the keyword "肺气肿" (emphysema).

    Returns:
        The raw label value from the first matching row, or None when no label column
        exists, the name is not found, or the label cell is empty.
    """
    # Locate the label column first: without it no lookup can succeed.
    emphysema_cols = [col for col in df2.columns if '肺气肿' in str(col)]
    if not emphysema_cols or patient_name is None:
        return None
    target = str(patient_name).strip()
    # Strip the sheet's names as well; stray spaces in Excel cells otherwise block matches.
    names = df2.iloc[:, 1].astype(str).str.strip()
    matching_rows = df2[names == target]
    if matching_rows.empty:
        return None
    label = matching_rows[emphysema_cols[0]].iloc[0]
    # An empty cell means "no label", not an invalid one.
    if pd.isna(label):
        return None
    return label
def _dicom_to_square_image(dcm, image_size=IMAGE_SIZE):
    """Convert a DICOM's pixel data into a centre-cropped, resized float image in [0, 1].

    Raises:
        ValueError: If the pixel array is not a single 2-D grayscale frame.

    NOTE: joblib's memory.cache does not track helper functions, so clear ./cache after
    editing this helper or stale preprocessed arrays will be reused.
    """
    img = dcm.pixel_array.astype(np.float64)
    if img.ndim != 2:
        # Multi-frame or colour data cannot go through the square crop below.
        raise ValueError(f"unsupported pixel array shape {img.shape}")
    # Centre-crop to a square whose side is the shorter image dimension.
    height, width = img.shape
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    img = img[top:top + side, left:left + side]
    img = resize(img, (image_size, image_size), order=3, anti_aliasing=True)
    # Min-max scale to [0, 1]; the epsilon guards against constant images.
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    # MONOCHROME1 displays the minimum pixel value as white (the inverse of MONOCHROME2);
    # invert it so every image uses the same intensity convention.
    if str(getattr(dcm, 'PhotometricInterpretation', '')).strip() == 'MONOCHROME1':
        img = 1.0 - img
    return img


@memory.cache
def load_and_preprocess_data():
    """Load labelled DICOM images and preprocess them into model-ready arrays.

    Each DICOM's accession number is mapped to a patient name via the first Excel sheet,
    the name to an emphysema label via the second sheet, and the pixels are converted by
    _dicom_to_square_image. Skipped files are counted per reason and their names are
    written to 'filtered_files.txt'.

    Returns:
        X: float32 array of shape (n, IMAGE_SIZE, IMAGE_SIZE) with values in [0, 1]
           (shape (0,) when nothing could be loaded).
        y: int64 array of shape (n,) with labels in {0, 1}.
    """
    # Metadata: sheet 1 maps accession -> patient name, sheet 2 maps name -> labels.
    df1 = pd.read_excel(
        r'/home/cfu/.virtualenvs/DR/tmp/pycharm_project_689/tmp/pycharm_project_715/DR已筛选2019-2023.xlsx',
        sheet_name=0)
    df2 = pd.read_excel(
        r'/home/cfu/.virtualenvs/DR/tmp/pycharm_project_689/tmp/pycharm_project_715/副本2025.7.14-基础信息+肺功能-郭屿老师.xlsx',
        sheet_name=0)
    X, y = [], []
    dicom_dir = r'/home/cfu/.virtualenvs/DR/TEST/NEW.DR'
    filter_reasons = {
        'other_error': 0,
        'no_pixel_data': 0,
        'no_accession': 0,
        'no_patient_name': 0,
        'no_label': 0,
        'invalid_label': 0
    }
    filtered_filenames = []

    def skip(filename, reason, message=None):
        """Record ``filename`` as skipped under ``reason``, printing ``message`` if given."""
        if message is not None:
            print(message)
        filter_reasons[reason] += 1
        filtered_filenames.append(filename)

    print("开始处理DICOM文件...")
    print(f"DR已筛选2019-2023.xlsx 列名: {list(df1.columns)}")
    print(f"2025.7.14-基础信息+肺功能-郭屿老师.xlsx 列名: {list(df2.columns)}")
    for filename in tqdm(os.listdir(dicom_dir), desc="Processing DICOM"):
        if not filename.endswith('.dcm'):
            continue
        try:
            dcm = pydicom.dcmread(os.path.join(dicom_dir, filename))
            if 'PixelData' not in dcm:  # explicit pixel-data check
                skip(filename, 'no_pixel_data', f"跳过无像素数据文件:{filename}")
                continue
            if not hasattr(dcm, 'AccessionNumber') or not dcm.AccessionNumber:
                skip(filename, 'no_accession', f"跳过无放射编号文件:{filename}")
                continue
            accession = str(dcm.AccessionNumber).strip()
            print(f"处理文件 {filename}: 放射编号={accession}")
            patient_name = get_patient_name_by_accession(accession, df1)
            if not patient_name:
                skip(filename, 'no_patient_name', f"警告:放射编号 {accession} 未找到对应病人姓名,跳过该样本")
                continue
            print(f" 病人姓名: {patient_name}")
            label = get_emphysema_label_by_name(patient_name, df2)
            if label is None:
                skip(filename, 'no_label', f"警告:病人 {patient_name} 未找到肺气肿标签,跳过该样本")
                continue
            # Accept 0/1 given as int, float (1.0) or text; reject everything else, including
            # non-integral values such as 0.5 that int() would silently truncate.
            try:
                label_value = float(label)
            except (ValueError, TypeError):
                skip(filename, 'invalid_label', f"警告:病人 {patient_name} 的标签无法转换为整数({label}),跳过该样本")
                continue
            if label_value not in (0.0, 1.0):
                skip(filename, 'invalid_label', f"警告:病人 {patient_name} 的标签无效({label}),跳过该样本")
                continue
            label = int(label_value)
            print(f" 肺气肿标签: {label}")
            img = _dicom_to_square_image(dcm)
        except Exception as e:  # one unreadable file must not abort the whole run
            print(f"处理文件 {filename} 时出错: {e}")
            skip(filename, 'other_error')
            continue
        X.append(img)
        y.append(label)

    # Save the skipped filenames for manual review (get_processed_accessions also reads them).
    with open('filtered_files.txt', 'w') as f:
        f.write('\n'.join(filtered_filenames))
    print("\n数据过滤统计:")
    for reason, count in filter_reasons.items():
        print(f"{reason}: {count} 个文件")
    X = np.array(X, dtype=np.float32)
    # np.long does not exist in NumPy 1.24-1.26 and means C long (32-bit on Windows) in
    # NumPy 2, so use an explicit 64-bit integer dtype.
    y = np.array(y, dtype=np.int64)
    print(f"\n成功加载 {len(X)} 个样本")
    print(f"标签分布: 0={np.sum(y == 0)}, 1={np.sum(y == 1)}")
    return X, y
# -------------------- Dataset split --------------------
# Load the preprocessed images and labels (served from the joblib cache when available).
X, y = load_and_preprocess_data()
@memory.cache
def get_processed_accessions():
    """Return the accession number of every DICOM the preprocessing step kept.

    Files listed in 'filtered_files.txt' are skipped. The remaining '.dcm' files are
    visited in os.listdir order, which is the order load_and_preprocess_data iterates in,
    so the result is meant to line up index-for-index with X. A file whose header cannot
    be read contributes None, which keeps the alignment.

    NOTE: joblib caches on this function's code only; it does not notice changes to
    'filtered_files.txt' or the DICOM directory.
    """
    dicom_dir = r'/home/cfu/.virtualenvs/DR/TEST/NEW.DR'
    filtered_files = set()
    if os.path.exists('filtered_files.txt'):
        with open('filtered_files.txt', 'r') as f:
            filtered_files = {line.strip() for line in f}
    valid_filenames = [name for name in os.listdir(dicom_dir)
                       if name.endswith('.dcm') and name not in filtered_files]
    accessions = []
    for filename in valid_filenames:
        try:
            # Only the header is needed, so skip reading the (large) pixel data.
            dcm = pydicom.dcmread(os.path.join(dicom_dir, filename), stop_before_pixels=True)
            # Same normalisation as load_and_preprocess_data uses for the lookup key.
            accessions.append(str(dcm.AccessionNumber).strip())
        except Exception as e:
            print(f"读取文件 {filename} 的Accession时出错: {e}")
            accessions.append(None)
    return accessions
# Abort when preprocessing produced no usable samples.
if len(X) == 0 or len(y) == 0:
    print("错误:未加载到任何有效样本!")
    exit(1)
# Accession number of every kept DICOM, collected by a second pass over the directory.
accessions = get_processed_accessions()
# Filenames skipped during preprocessing (written by load_and_preprocess_data).
filtered_files = []
if os.path.exists('filtered_files.txt'):
    with open('filtered_files.txt', 'r') as f:
        filtered_files = [line.strip() for line in f.readlines()]
# The two passes should agree in length. If they do not, print diagnostics and force the
# lengths to match.
# NOTE(review): truncating/padding only fixes the length; entries after the first
# discrepancy may no longer belong to the image at the same index. In the visible code the
# accessions are only carried through train_test_split, so this is not fatal today.
if len(accessions) != len(X):
    print(f"警告:Accession列表长度 ({len(accessions)}) 与样本数量 ({len(X)}) 不一致")
    dicom_dir = r'/home/cfu/.virtualenvs/DR/TEST/NEW.DR'
    valid_files_count = len([f for f in os.listdir(dicom_dir)
                             if f.endswith('.dcm') and f not in filtered_files])
    print(f"有效文件数量: {valid_files_count}")
    print(f"处理后样本数量: {len(X)}")
    print(f"被过滤的文件数量: {len(filtered_files)}")
    if len(accessions) > len(X):
        print(f"截断Accession列表,从 {len(accessions)} 缩减到 {len(X)}")
        accessions = accessions[:len(X)]
    else:
        print(f"补充Accession列表,从 {len(accessions)} 扩展到 {len(X)}")
        accessions += [None] * (len(X) - len(accessions))
    assert len(accessions) == len(X), "调整后仍然不匹配,请检查数据处理流程"
    print("已自动调整使Accession列表与样本数量匹配")
# Split 80/20 into train+val / test, then 80/20 again into train / val, stratified by label.
# NOTE(review): the split is per image, not per patient. If a patient has several DICOMs,
# images of the same person can land in both train and test; confirm or split by patient.
X_train, X_test, y_train, y_test, acc_train, acc_test = train_test_split(
    X, y, accessions, test_size=0.2, random_state=42, stratify=y
)
X_train, X_val, y_train, y_val, acc_train, acc_val = train_test_split(
    X_train, y_train, acc_train, test_size=0.2, random_state=42, stratify=y_train
)

# Oversample the minority class in the training split only (val/test stay untouched).
# SMOTE works on flat feature vectors, so images are flattened and reshaped back afterwards.
# NOTE(review): SMOTE on raw pixels synthesises blended "double exposure" X-rays that are
# all labelled positive; consider loss weighting instead if these artefacts hurt results.
smote = SMOTE(random_state=42, sampling_strategy=0.5)
image_shape = X_train.shape[1:]
X_train_flat = X_train.reshape(X_train.shape[0], -1)
X_train_resampled_flat, y_train_resampled = smote.fit_resample(X_train_flat, y_train)
total_elements = X_train_resampled_flat.size
print(f"X_train_resampled 的元素数量: {total_elements}")
# Every synthetic row has exactly as many features as an input image, so the original image
# shape is always recoverable. Guessing some other square size would scramble the pixels.
X_train_resampled = X_train_resampled_flat.reshape((-1,) + image_shape)
current_image_size = image_shape[0]

# Label counts per split
total_zeros = np.sum(y == 0)
total_ones = np.sum(y == 1)
train_zeros = np.sum(y_train_resampled == 0)
train_ones = np.sum(y_train_resampled == 1)
val_zeros = np.sum(y_val == 0)
val_ones = np.sum(y_val == 1)
test_zeros = np.sum(y_test == 0)
test_ones = np.sum(y_test == 1)
print(f"全部数据中标签为 0 的数量: {total_zeros}")
print(f"全部数据中标签为 1 的数量: {total_ones}")
print(f"训练集中标签为 0 的数量: {train_zeros}")
print(f"训练集中标签为 1 的数量: {train_ones}")
print(f"验证集中标签为 0 的数量: {val_zeros}")
print(f"验证集中标签为 1 的数量: {val_ones}")
print(f"测试集中标签为 0 的数量: {test_zeros}")
print(f"测试集中标签为 1 的数量: {test_ones}")
# Dataset
class MedicalDataset(Dataset):
    """Dataset of preprocessed grayscale X-rays with binary labels.

    Args:
        images: Array/sequence of images, each 2-D (H, W) or flattened (H*W,).
        labels: Sequence of 0/1 labels aligned with ``images``.
        transform: Optional torchvision transform applied to an 8-bit grayscale PIL image;
            it must end with a tensor conversion. Defaults to ToTensor().
        image_size: Side length used to un-flatten 1-D images.

    Items are (image tensor of shape (1, H, W), float label tensor).
    """

    def __init__(self, images, labels, transform=None, image_size=IMAGE_SIZE):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.image_size = image_size
        self.xrv_transform = xrv.datasets.XRayCenterCrop()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = np.asarray(self.images[idx], dtype=np.float32)
        label = self.labels[idx]
        if img.ndim == 1:
            img = img.reshape(self.image_size, self.image_size)
        if img.ndim == 2:
            img = img[np.newaxis, ...]  # XRayCenterCrop expects channel-first (1, H, W)
        img = self.xrv_transform(img).squeeze(0)
        # Rescale to [0, 1] and quantise to 8 bits. A float32 array would become a PIL "F"
        # image, and PIL's Image.blend (used by ColorJitter's contrast step) rejects that mode.
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        img_pil = to_pil_image(np.round(img * 255).astype(np.uint8))
        if self.transform:
            tensor = self.transform(img_pil)
        else:
            tensor = transforms.ToTensor()(img_pil)
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
        return tensor, torch.tensor(label, dtype=torch.float)
# Data loaders. The training loader shuffles and reads the SMOTE-resampled images (side
# length current_image_size); the evaluation loaders use twice the batch size.
train_loader = DataLoader(
    MedicalDataset(X_train_resampled, y_train_resampled, transform=train_transform, image_size=current_image_size),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
val_loader = DataLoader(
    MedicalDataset(X_val, y_val, transform=val_test_transform, image_size=IMAGE_SIZE),
    batch_size=BATCH_SIZE * 2,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
test_loader = DataLoader(
    MedicalDataset(X_test, y_test, transform=val_test_transform, image_size=IMAGE_SIZE),
    batch_size=BATCH_SIZE * 2,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
# Loader used to fit the probability calibrator; it reads the same validation split as
# val_loader.
# NOTE(review): that split also drives early stopping and checkpoint selection, so the
# calibrator is fitted on data the model was tuned on.
calibration_loader = DataLoader(
    MedicalDataset(X_val, y_val, transform=val_test_transform, image_size=IMAGE_SIZE),
    batch_size=BATCH_SIZE * 2,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
# Training configuration: Focal Loss
class FocalLoss(nn.Module):
    """Binary focal loss on logits (Lin et al., "Focal Loss for Dense Object Detection").

    loss_i = alpha_t * (1 - p_t) ** gamma * BCE(logit_i, target_i), averaged over the batch,
    where p_t is the predicted probability of the true class and
    alpha_t = alpha for positives (target 1) and 1 - alpha for negatives (target 0).

    Args:
        alpha: Positive-class weight in [0, 1]; 0.5 weights both classes equally.
        gamma: Focusing parameter; 0 reduces the loss to class-weighted BCE.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        bce_loss = self.bce_loss(inputs, targets)
        # exp(-BCE) is the probability the model assigns to the true class (p_t).
        pt = torch.exp(-bce_loss)
        # Class-dependent weight: multiplying every sample by the same alpha would merely
        # rescale the loss and never re-weight the classes.
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_loss = alpha_t * (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean()
# Class distribution. Report the original training split and the SMOTE-resampled set;
# the loss is computed on batches drawn from the resampled set.
neg_count = np.sum(y_train == 0)
pos_count = np.sum(y_train == 1)
total_count = len(y_train)
print(f"训练集分布 - 负例: {neg_count}, 正例: {pos_count}, 总计: {total_count}")
resampled_neg = np.sum(y_train_resampled == 0)
resampled_pos = np.sum(y_train_resampled == 1)
print(f"SMOTE后训练集分布 - 负例: {resampled_neg}, 正例: {resampled_pos}, 总计: {len(y_train_resampled)}")
# Focal Loss parameters. In standard focal loss, alpha weights positives and 1 - alpha
# weights negatives. SMOTE has already raised the minority share of the training set, so a
# positive-favouring alpha such as 0.75 would compensate for the imbalance a second time,
# push scores towards the positive class and cost specificity. 0.5 weights both classes
# equally and leaves the rebalancing to SMOTE.
alpha_value = 0.5
gamma_value = 2.0
criterion = FocalLoss(alpha=alpha_value, gamma=gamma_value)
print(f"使用Focal Loss - alpha: {alpha_value}, gamma: {gamma_value}")
# Model definition
class CustomDenseNet(nn.Module):
    """torchxrayvision DenseNet-121 backbone with a new single-logit classification head.

    Args:
        pretrained: If True (default), load the "densenet121-res224-chex" weights;
            otherwise the backbone is randomly initialised.

    forward(x) returns raw logits of shape (batch, 1); apply a sigmoid for probabilities.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = "densenet121-res224-chex" if pretrained else None
        self.base_model = xrv.models.DenseNet(weights=weights)
        self.features = self.base_model.features
        num_ftrs = self.base_model.classifier.in_features
        self.classifier = nn.Linear(num_ftrs, 1)
        # op_threshs belongs to the original multi-pathology head, which is not used here.
        if hasattr(self.base_model, 'op_threshs'):
            delattr(self.base_model, 'op_threshs')

    def forward(self, x):
        x = self.features(x)
        # The DenseNet feature stack ends in a BatchNorm layer; the reference models
        # (torchvision and torchxrayvision) apply ReLU before global average pooling, and
        # the pretrained features were learned with that ReLU in place.
        x = torch.relu(x)
        x = x.mean(dim=(2, 3))
        return self.classifier(x)
# Build the model on the target device and wrap it in DataParallel, which replicates it
# across the visible GPUs (a single GPU with CUDA_VISIBLE_DEVICES="2" above).
model = CustomDenseNet().to(device)
model = nn.DataParallel(model)
model = model.to(device)
# Every parameter that requires a gradient goes to the optimizer. This includes the
# backbone's original classifier, which CustomDenseNet.forward never calls, so it never
# receives a gradient. The printed number counts parameter tensors, not scalar weights.
params_to_update = [p for p in model.parameters() if p.requires_grad]
print(f"可训练参数数量: {len(params_to_update)}")
optimizer = optim.AdamW(
    params_to_update,
    lr=5e-5,
    weight_decay=0.001
)
# Cut the learning rate 10x after 5 epochs without validation-loss improvement.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
# Early-stopping parameters (re-initialised again right before the training loop).
patience = 30
no_improvement_epochs = 0
# Per-epoch metric histories
train_losses = []
val_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
# Test metrics at the epoch with the highest test AUC (bookkeeping only, not used for
# checkpoint selection).
highest_auc = 0
highest_auc_epoch = 0
highest_auc_accuracy = 0
highest_auc_precision = 0
highest_auc_recall = 0
highest_auc_f1 = 0
highest_auc_sensitivity = 0
highest_auc_specificity = 0
highest_auc_conf_matrix = None
# -------------------- Calibration helpers --------------------
def improved_plot_calibration_curve(true_labels, probs, n_bins=10, model_name='Model'):
    """Plot a reliability diagram with per-bin sample counts and the Brier score.

    Args:
        true_labels: Binary ground-truth labels.
        probs: Predicted positive-class probabilities in [0, 1].
        n_bins: Number of equal-width probability bins.
        model_name: Name used in the legend and title.

    Returns:
        The matplotlib.pyplot module, with the figure still open for the caller to save.
    """
    y_true = np.asarray(true_labels, dtype=float).ravel()
    y_prob = np.asarray(probs, dtype=float).ravel()
    plt.figure(figsize=(12, 8))
    # Assign each probability to one equal-width bin (bin i covers (edge_i, edge_i+1],
    # matching recent scikit-learn calibration_curve) and drop empty bins. Means and
    # counts then come from the same assignment, so every "n=" label sits on its point.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.searchsorted(edges[1:-1], y_prob)
    counts = np.bincount(bin_ids, minlength=n_bins)
    nonempty = counts > 0
    bin_counts = counts[nonempty]
    prob_true = np.bincount(bin_ids, weights=y_true, minlength=n_bins)[nonempty] / bin_counts
    prob_pred = np.bincount(bin_ids, weights=y_prob, minlength=n_bins)[nonempty] / bin_counts
    brier_score = brier_score_loss(y_true, y_prob)

    plt.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
    plt.plot(prob_pred, prob_true, "s-", label=f'{model_name} (Brier = {brier_score:.4f})',
             linewidth=2, markersize=8)
    for x_pos, y_pos, count in zip(prob_pred, prob_true, bin_counts):
        plt.annotate(f'n={count}', (x_pos, y_pos), textcoords="offset points",
                     xytext=(0, 10), ha='center', fontsize=9)
    # Histogram as a fraction of samples so it shares the [0, 1] y-axis with the curve
    # instead of stretching the axis with raw counts.
    plt.hist(y_prob, range=(0, 1), bins=n_bins, histtype="step", lw=2, alpha=0.7,
             weights=np.full(y_prob.shape, 1.0 / max(y_prob.size, 1)),
             label=f'Probability distribution (n={len(y_prob)})')
    plt.xlabel("Mean predicted probability", fontsize=12)
    plt.ylabel("Fraction of positives", fontsize=12)
    plt.title(f"Calibration Curve - {model_name}", fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.text(0.02, 0.98, f'Brier Score: {brier_score:.4f}', transform=plt.gca().transAxes,
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
             verticalalignment='top', fontsize=11)
    plt.tight_layout()
    return plt
def calibrate_model(model, calibration_loader, device, method='isotonic'):
    """Fit a probability calibrator on the predictions for a held-out loader.

    Args:
        model: Trained PyTorch model that outputs one logit per sample.
        calibration_loader: DataLoader yielding (inputs, labels) used to fit the calibrator.
        device: Device the inputs are moved to.
        method: 'isotonic' (IsotonicRegression) or 'platt' (LogisticRegression fitted on
            the predicted probability).

    Returns:
        Tuple (calibrator, calibrated_predict_proba, all_probs, all_labels):
        the fitted scikit-learn estimator; a function mapping an array of raw probabilities
        to a 1-D array of calibrated probabilities; the raw calibration-set probabilities
        with shape (n, 1); and the matching labels with shape (n,).

    Raises:
        ValueError: If ``method`` is not 'isotonic' or 'platt'.
    """
    if method not in ('isotonic', 'platt'):
        raise ValueError("method must be 'isotonic' or 'platt'")
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in calibration_loader:
            inputs = inputs.to(device)
            # view(-1) rather than squeeze(): a final batch of one must stay 1-D, otherwise
            # extend() below fails on a 0-d array.
            outputs = model(inputs).view(-1)
            all_probs.extend(torch.sigmoid(outputs).cpu().numpy())
            all_labels.extend(labels.view(-1).cpu().numpy())
    all_probs = np.array(all_probs).reshape(-1, 1)
    all_labels = np.array(all_labels)

    if method == 'isotonic':
        calibrator = IsotonicRegression(out_of_bounds='clip')
        calibrator.fit(all_probs.flatten(), all_labels)

        def calibrated_predict_proba(probs):
            return calibrator.predict(np.asarray(probs).flatten()).reshape(-1)
    else:
        calibrator = LogisticRegression()
        calibrator.fit(all_probs, all_labels)

        def calibrated_predict_proba(probs):
            return calibrator.predict_proba(np.asarray(probs).reshape(-1, 1))[:, 1]

    return calibrator, calibrated_predict_proba, all_probs, all_labels
def plot_calibration_comparison(true_labels, uncalibrated_probs, calibrated_probs):
    """Draw side-by-side reliability diagrams before and after calibration.

    Args:
        true_labels: Binary ground-truth labels.
        uncalibrated_probs: Raw model probabilities.
        calibrated_probs: Probabilities after calibration.

    Returns:
        The matplotlib.pyplot module, with the figure still open for the caller to save.
    """
    y_true = np.asarray(true_labels, dtype=float).ravel()
    panels = [
        (np.asarray(uncalibrated_probs, dtype=float).ravel(), 'Uncalibrated', "Before Calibration", {}),
        (np.asarray(calibrated_probs, dtype=float).ravel(), 'Calibrated', "After Calibration", {'color': 'green'}),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    for ax, (probs, name, title, style) in zip(axes, panels):
        brier = brier_score_loss(y_true, probs)
        prob_true, prob_pred = calibration_curve(y_true, probs, n_bins=10)
        ax.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
        ax.plot(prob_pred, prob_true, "s-", label=f'{name} (Brier = {brier:.4f})',
                linewidth=2, markersize=8, **style)
        # Weight each sample by 1/N so the histogram shows a fraction of samples and shares
        # the [0, 1] y-axis with the curve instead of dwarfing it with raw counts.
        ax.hist(probs, range=(0, 1), bins=10, histtype="step", lw=2, alpha=0.7,
                weights=np.full(probs.shape, 1.0 / max(probs.size, 1)), **style)
        ax.set_xlabel("Mean predicted probability")
        ax.set_ylabel("Fraction of positives")
        ax.set_title(title, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
    plt.suptitle("Model Calibration Comparison", fontsize=16, fontweight='bold')
    plt.tight_layout()
    return plt
def evaluate_calibration(true_labels, probs, n_bins=10):
    """Summarise the calibration quality of binary probability predictions.

    Probabilities are grouped into ``n_bins`` equal-width bins on [0, 1], where bin i
    covers (edge_i, edge_i+1] and the first bin also includes 0. Empty bins are dropped,
    so all returned per-bin arrays have one aligned entry per non-empty bin.

    Args:
        true_labels: Binary ground-truth labels (0/1).
        probs: Predicted positive-class probabilities in [0, 1].
        n_bins: Number of equal-width bins (default 10).

    Returns:
        dict with 'brier_score', 'expected_calibration_error' (bin-size-weighted mean of
        |fraction of positives - mean predicted probability|), 'max_calibration_error',
        and the aligned per-bin arrays 'prob_true', 'prob_pred' and 'bin_counts'.

    Raises:
        ValueError: If the inputs are empty, differ in length, or probs leave [0, 1].
    """
    y_true = np.asarray(true_labels, dtype=float).ravel()
    y_prob = np.asarray(probs, dtype=float).ravel()
    if y_prob.size == 0 or y_true.size != y_prob.size:
        raise ValueError("true_labels and probs must be non-empty and of equal length")
    if y_prob.min() < 0 or y_prob.max() > 1:
        raise ValueError("probs must lie in [0, 1]")

    brier_score = float(np.mean((y_prob - y_true) ** 2))

    # One bin assignment feeds counts, accuracies and confidences alike, so the ECE weights
    # are guaranteed to belong to the bins they multiply.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.searchsorted(edges[1:-1], y_prob)
    counts = np.bincount(bin_ids, minlength=n_bins)
    nonempty = counts > 0
    bin_counts = counts[nonempty]
    prob_true = np.bincount(bin_ids, weights=y_true, minlength=n_bins)[nonempty] / bin_counts
    prob_pred = np.bincount(bin_ids, weights=y_prob, minlength=n_bins)[nonempty] / bin_counts

    gaps = np.abs(prob_true - prob_pred)
    ece = float(np.sum(gaps * bin_counts) / np.sum(bin_counts))
    mce = float(np.max(gaps))
    return {
        'brier_score': brier_score,
        'expected_calibration_error': ece,
        'max_calibration_error': mce,
        'prob_true': prob_true,
        'prob_pred': prob_pred,
        'bin_counts': bin_counts
    }
# -------------------- Other evaluation helpers --------------------
def check_preprocessing(model, test_loader):
    """Sanity-check the input pipeline on the first batch of ``test_loader``.

    Switches ``model`` to eval mode, then prints the shape, value range, mean and standard
    deviation of the first input batch, followed by the shape and range of the model output.
    Does nothing further when the loader is empty.
    """
    model.eval()
    target_device = next(model.parameters()).device
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        return
    batch, _labels = first_batch
    batch = batch.to(target_device)
    print(f"输入图像形状: {batch.shape}")
    print(f"输入图像范围: [{batch.min()}, {batch.max()}]")
    print(f"输入图像均值: {batch.mean()}")
    print(f"输入图像标准差: {batch.std()}")
    with torch.no_grad():
        logits = model(batch)
    print(f"模型输出形状: {logits.shape}")
    print(f"模型输出范围: [{logits.min()}, {logits.max()}]")
def mixup_data(x, y, alpha=0.5):
    """Blend a batch with a shuffled copy of itself (mixup, Zhang et al. 2018).

    Args:
        x: Input batch tensor; the first dimension is the batch dimension.
        y: Label tensor aligned with ``x``.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        (mixed_x, y_a, y_b, lam) with mixed_x = lam * x + (1 - lam) * x[perm], y_a = y and
        y_b = y[perm]. Train with lam * loss(y_a) + (1 - lam) * loss(y_b).
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    # Draw the permutation with the CPU generator and move it to the batch's own device,
    # so the function does not depend on the module-level `device`.
    index = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: blend the losses against both targets with the input-mixing weight ``lam``."""
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second
def calculate_metrics(true_labels, predictions, probs):
    """Compute binary-classification metrics.

    Args:
        true_labels: Ground-truth 0/1 labels.
        predictions: Hard 0/1 predictions.
        probs: Positive-class scores used for ROC AUC.

    Returns:
        dict with accuracy, precision, recall, sensitivity (= recall), specificity, f1, auc,
        conf_matrix (2x2, rows = true class 0/1, columns = predicted class 0/1) and its
        tn/fp/fn/tp entries. 'auc' is NaN when the labels contain a single class.
    """
    y_true = np.asarray(true_labels).astype(int).ravel()
    y_pred = np.asarray(predictions).astype(int).ravel()
    y_score = np.asarray(probs, dtype=float).ravel()
    # labels=[0, 1] forces a 2x2 matrix even if a class is absent, so the 4-way unpacking
    # below cannot fail.
    conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = conf_matrix.ravel()
    # zero_division=0 reports 0 instead of warning when no positives are predicted/present.
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    # ROC AUC is undefined with only one class present (roc_auc_score would raise).
    average_auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) == 2 else float('nan')
    return {
        'accuracy': np.mean(y_true == y_pred),
        'precision': precision,
        'recall': recall,
        'sensitivity': recall,
        'specificity': specificity,
        'f1': f1,
        'auc': average_auc,
        'conf_matrix': conf_matrix,
        'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp
    }
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` using the module-level ``criterion`` and ``device``.

    Returns:
        Tuple (mean batch loss, accuracy in percent at threshold 0.5, ROC AUC,
        metrics dict from calculate_metrics).

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    all_probs = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True).float().view(-1)
            # view(-1) rather than squeeze(): a one-sample batch must stay 1-D so the loss
            # sees matching shapes and the probabilities can be extended into a list.
            outputs = model(inputs).view(-1)
            total_loss += criterion(outputs, labels).item()
            all_probs.extend(torch.sigmoid(outputs).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    if not all_labels:
        raise ValueError("validate() received an empty loader")
    predictions = (np.asarray(all_probs) > 0.5).astype(int)
    metrics = calculate_metrics(all_labels, predictions, all_probs)
    accuracy = 100 * metrics['accuracy']
    return total_loss / len(loader), accuracy, metrics['auc'], metrics
def plot_confusion_matrix(conf_matrix, title='Confusion Matrix'):
    """Render a binary confusion matrix with per-cell counts.

    Args:
        conf_matrix: 2x2 integer array laid out as [[tn, fp], [fn, tp]].
        title: Title shown above the matrix.

    Sensitivity and specificity are written beneath the plot. The figure is saved to
    'confusion_matrix.png' and then shown.
    """
    plt.figure(figsize=(8, 6))
    tn, fp, fn, tp = conf_matrix.ravel()
    plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    class_names = ['Negative', 'Positive']
    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names, rotation=45)
    plt.yticks(positions, class_names)
    # Cells above half the maximum count are dark, so their numbers are drawn in white.
    half_max = conf_matrix.max() / 2.
    n_rows, n_cols = conf_matrix.shape
    for row in range(n_rows):
        for col in range(n_cols):
            cell = conf_matrix[row, col]
            text_color = "white" if cell > half_max else "black"
            plt.text(col, row, format(cell, 'd'),
                     horizontalalignment="center",
                     color=text_color)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    actual_positives = tp + fn
    actual_negatives = tn + fp
    sensitivity = tp / actual_positives if actual_positives > 0 else 0.0
    specificity = tn / actual_negatives if actual_negatives > 0 else 0.0
    plt.figtext(0.5, 0.01,
                f'Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}',
                ha='center', fontsize=12, bbox={"facecolor": "orange", "alpha": 0.2, "pad": 5})
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
def plot_pr_curve(true_labels, probs):
    """Plot the precision-recall curve against a no-skill baseline.

    Args:
        true_labels: Binary ground-truth labels.
        probs: Positive-class scores.

    Returns:
        Trapezoidal area under the precision-recall curve. The figure is saved to
        'pr_curve.png' and shown.
    """
    precision, recall, _ = precision_recall_curve(true_labels, probs)
    pr_auc = auc(recall, precision)
    # A no-skill classifier's precision equals the positive prevalence.
    baseline = np.sum(true_labels) / len(true_labels)
    plt.figure(figsize=(10, 8))
    plt.plot(recall, precision, color='darkorange', lw=2,
             label=f'PR curve (AUC = {pr_auc:.4f})')
    plt.plot([0, 1], [baseline, baseline],
             color='navy', lw=2, linestyle='--',
             label=f'No Skill (Precision = {baseline:.4f})')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall (Sensitivity)')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig('pr_curve.png', dpi=300)
    plt.show()
    return pr_auc
def plot_dca_curve(true_labels, probs, thresholds=None):
    """Decision curve analysis: net benefit of the model versus treat-all and treat-none.

    Net benefit at threshold probability p_t is TP/N - FP/N * p_t / (1 - p_t)
    (Vickers & Elkin, 2006). At p_t >= 1 the odds term is infinite and 0 is reported.

    Args:
        true_labels: Binary ground-truth labels (list or array).
        probs: Predicted positive-class probabilities (list or array).
        thresholds: Threshold probabilities to evaluate; defaults to 100 points on [0, 1].

    The figure is saved to 'dca_curve.png' and shown.
    """
    # Arrays are required: with a plain list, `true_labels == 1` is a single False, which
    # silently made every TP/FP count zero.
    y_true = np.asarray(true_labels).astype(int).ravel()
    y_prob = np.asarray(probs, dtype=float).ravel()
    if thresholds is None:
        thresholds = np.linspace(0, 1, 100)
    n = len(y_true)
    event_rate = np.mean(y_true)
    net_benefit_model = []
    net_benefit_treat_all = []
    for pt in thresholds:
        if pt >= 1:
            net_benefit_model.append(0)
            net_benefit_treat_all.append(0)
            continue
        odds = pt / (1 - pt)
        predicted_high_risk = y_prob > pt
        tp = np.sum(predicted_high_risk & (y_true == 1))
        fp = np.sum(predicted_high_risk & (y_true == 0))
        net_benefit_model.append((tp - fp * odds) / n)
        net_benefit_treat_all.append(event_rate - odds * (1 - event_rate))
    net_benefit_treat_none = [0] * len(thresholds)

    plt.figure(figsize=(12, 8))
    plt.plot(thresholds, net_benefit_model, label='Model', linewidth=2)
    plt.plot(thresholds, net_benefit_treat_all, label='Treat all', linestyle='--', linewidth=2)
    plt.plot(thresholds, net_benefit_treat_none, label='Treat none', linestyle=':', linewidth=2)
    # Treat-all net benefit heads to -infinity as p_t -> 1; clip the y-axis to the relevant
    # region so the model curve remains readable.
    top = max(max(net_benefit_model, default=0.0), event_rate) + 0.05
    plt.ylim(-0.05, top)
    plt.xlabel('Threshold Probability')
    plt.ylabel('Net Benefit')
    plt.title('Decision Curve Analysis')
    plt.legend()
    plt.grid(True)
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    plt.savefig('dca_curve.png', dpi=300)
    plt.show()
def plot_comprehensive_evaluation(true_labels, probs):
    """Draw a 2x2 overview: ROC, precision-recall, reliability diagram and score histograms.

    Args:
        true_labels: Binary ground-truth labels (list or array).
        probs: Predicted positive-class probabilities (list or array).

    The figure is saved to 'comprehensive_evaluation.png' and shown.
    """
    # Work on arrays: with plain lists, `true_labels == 0` is a single False and
    # probs[False] silently returns probs[0], so the class histograms showed one value.
    true_labels = np.asarray(true_labels).astype(int).ravel()
    probs = np.asarray(probs, dtype=float).ravel()
    prevalence = true_labels.mean()  # precision of a no-skill classifier
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # 1. ROC curve
    fpr, tpr, _ = roc_curve(true_labels, probs)
    roc_auc = auc(fpr, tpr)
    axes[0, 0].plot(fpr, tpr, color='darkorange', lw=2,
                    label=f'ROC curve (AUC = {roc_auc:.4f})')
    axes[0, 0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    axes[0, 0].set_xlim([0.0, 1.0])
    axes[0, 0].set_ylim([0.0, 1.05])
    axes[0, 0].set_xlabel('False Positive Rate (1-Specificity)')
    axes[0, 0].set_ylabel('True Positive Rate (Sensitivity)')
    axes[0, 0].set_title('ROC Curve')
    axes[0, 0].legend(loc="lower right")
    axes[0, 0].grid(True)

    # 2. Precision-recall curve
    precision, recall, _ = precision_recall_curve(true_labels, probs)
    pr_auc = auc(recall, precision)
    axes[0, 1].plot(recall, precision, color='darkorange', lw=2,
                    label=f'PR curve (AUC = {pr_auc:.4f})')
    axes[0, 1].plot([0, 1], [prevalence, prevalence], color='navy', lw=2, linestyle='--')
    axes[0, 1].set_xlim([0.0, 1.0])
    axes[0, 1].set_ylim([0.0, 1.05])
    axes[0, 1].set_xlabel('Recall (Sensitivity)')
    axes[0, 1].set_ylabel('Precision')
    axes[0, 1].set_title('Precision-Recall Curve')
    axes[0, 1].legend(loc="lower left")
    axes[0, 1].grid(True)

    # 3. Reliability diagram
    prob_true, prob_pred = calibration_curve(true_labels, probs, n_bins=10)
    brier_score = brier_score_loss(true_labels, probs)
    axes[1, 0].plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
    axes[1, 0].plot(prob_pred, prob_true, "s-", label=f'Model (Brier = {brier_score:.4f})', linewidth=2)
    axes[1, 0].set_xlabel("Mean predicted probability")
    axes[1, 0].set_ylabel("Fraction of positives")
    axes[1, 0].set_title("Calibration Curve")
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # 4. Predicted-probability distribution per true class
    axes[1, 1].hist(probs[true_labels == 0], bins=20, alpha=0.5, label='Negative', density=True)
    axes[1, 1].hist(probs[true_labels == 1], bins=20, alpha=0.5, label='Positive', density=True)
    axes[1, 1].set_xlabel('Predicted Probability')
    axes[1, 1].set_ylabel('Density')
    axes[1, 1].set_title('Predicted Probability Distribution')
    axes[1, 1].legend()
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig('comprehensive_evaluation.png', dpi=300)
    plt.show()
# -------------------- Model training --------------------
print(model.module)
best_val_auc = 0.0
no_improvement_epochs = 0
patience = 30
for epoch in range(100):
    # ---- training ----
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
        inputs = inputs.to(device, non_blocking=True)
        # view(-1) (not squeeze()) keeps logits and targets 1-D even for a final batch of a
        # single sample; squeeze() would make a 0-d tensor and break the loss shapes.
        labels = labels.to(device, non_blocking=True).float().view(-1)
        optimizer.zero_grad(set_to_none=True)
        if random.random() < 0.3:  # apply mixup to roughly 30% of batches
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)
            outputs = model(inputs).view(-1)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
        else:
            outputs = model(inputs).view(-1)
            loss = criterion(outputs, labels)
        if torch.isnan(loss):
            print("NaN loss detected! Skipping batch.")
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        train_loss += loss.item()
        # Compared against the un-mixed labels, so accuracy on mixup batches is approximate.
        probs = torch.sigmoid(outputs).detach().cpu().numpy()
        predicted = (probs > 0.5).astype(float)
        correct += np.sum(predicted == labels.cpu().numpy())
        total += len(labels)
    train_accuracy = 100 * correct / total if total > 0 else 0.0
    train_losses.append(train_loss / len(train_loader))
    train_accuracies.append(train_accuracy)
    print(f"Epoch {epoch + 1:02d} | Training Accuracy: {train_accuracy:.2f}%")

    # ---- validation: drives checkpointing, early stopping and the LR schedule ----
    val_loss, val_accuracy, val_auc, val_metrics = validate(model, val_loader)
    val_losses.append(val_loss)
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | Val AUC: {val_auc:.4f}")
    print(f"Val Precision: {val_metrics['precision']:.4f} | Val Recall: {val_metrics['recall']:.4f}")
    print(f"Val Sensitivity: {val_metrics['sensitivity']:.4f} | Val Specificity: {val_metrics['specificity']:.4f}")
    print(f"Val F1-score: {val_metrics['f1']:.4f}")
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        no_improvement_epochs = 0
        torch.save(model.module.state_dict(), 'best_xray_densenet_model.pth')
        print(f"★ New best model saved (AUC={best_val_auc:.4f})")
    else:
        no_improvement_epochs += 1
        if no_improvement_epochs >= patience:
            print(f"Early stopping: No AUC improvement for {patience} epochs.")
            break
    scheduler.step(val_loss)

    # ---- test-set monitoring (reporting only) ----
    # NOTE(review): tracking the "highest test AUC" peeks at the test set; report only the
    # validation-selected checkpoint as the final result.
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    predictions = []
    true_labels = []
    predicted_probs = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float().view(-1)
            outputs = model(inputs).view(-1)
            test_loss += criterion(outputs, labels).item()
            probs = torch.sigmoid(outputs)
            predicted = (probs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            predicted_probs.extend(probs.cpu().numpy())
    test_accuracy = 100 * correct / total if total > 0 else 0.0
    test_losses.append(test_loss / len(test_loader))
    test_accuracies.append(test_accuracy)
    test_metrics = calculate_metrics(true_labels, predictions, predicted_probs)
    print(f"\n=== Epoch {epoch + 1:02d} Test Results ===")
    print(f"Test Loss: {test_loss / len(test_loader):.4f}")
    print(f"Test Accuracy: {test_metrics['accuracy'] * 100:.2f}%")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall (Sensitivity): {test_metrics['sensitivity']:.4f}")
    print(f"Specificity: {test_metrics['specificity']:.4f}")
    print(f"F1-score: {test_metrics['f1']:.4f}")
    print(f"AUC: {test_metrics['auc']:.4f}")
    print(f"Confusion Matrix:\n{test_metrics['conf_matrix']}")
    print(f"TN: {test_metrics['tn']}, FP: {test_metrics['fp']}, FN: {test_metrics['fn']}, TP: {test_metrics['tp']}")
    if test_metrics['auc'] > highest_auc:
        highest_auc = test_metrics['auc']
        highest_auc_epoch = epoch + 1
        highest_auc_accuracy = test_metrics['accuracy']
        highest_auc_precision = test_metrics['precision']
        highest_auc_recall = test_metrics['recall']
        highest_auc_f1 = test_metrics['f1']
        highest_auc_sensitivity = test_metrics['sensitivity']
        highest_auc_specificity = test_metrics['specificity']
        highest_auc_conf_matrix = test_metrics['conf_matrix']
# Plot the per-epoch loss (train/val/test) and accuracy (train/test) histories.
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
# NOTE(review): 'blue' may be hard to tell apart from the default colour of the first line.
plt.plot(test_losses, label='Test Loss', color='blue')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.grid(True)
plt.subplot(1, 2, 2)
# Validation accuracy is not stored per epoch, so only training and test accuracy appear.
plt.plot(train_accuracies, label='Training Accuracy', color='orange')
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.title('Training Accuracy Curve')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('training_metrics.png', dpi=300)
plt.show()
# -------------------- Final evaluation (calibration follows below) --------------------
print("\n=== Final Evaluation with Best Model ===")
model_path = 'best_xray_densenet_model.pth'
# map_location lets a checkpoint saved on one device load on another (e.g. CPU only).
state_dict = torch.load(model_path, map_location=device)
if hasattr(model, 'module'):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict)
model.eval()

# Raw test-set predictions of the validation-selected checkpoint.
all_test_labels = []
all_test_probs = []
all_test_predictions = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        # view(-1) instead of squeeze(): a single-sample final batch must stay 1-D so that
        # extend() below receives an iterable array.
        outputs = model(inputs).view(-1)
        probs = torch.sigmoid(outputs).cpu().numpy()
        predictions = (probs > 0.5).astype(int)
        all_test_labels.extend(labels.view(-1).cpu().numpy())
        all_test_probs.extend(probs)
        all_test_predictions.extend(predictions)

final_metrics = calculate_metrics(all_test_labels, all_test_predictions, all_test_probs)
print(f"\n=== Final Test Metrics (Before Calibration) ===")
print(f"Accuracy: {final_metrics['accuracy'] * 100:.2f}%")
print(f"Precision: {final_metrics['precision']:.4f}")
print(f"Recall (Sensitivity): {final_metrics['sensitivity']:.4f}")
print(f"Specificity: {final_metrics['specificity']:.4f}")
print(f"F1-score: {final_metrics['f1']:.4f}")
print(f"AUC: {final_metrics['auc']:.4f}")
print(f"Confusion Matrix:\n{final_metrics['conf_matrix']}")
# -------------------- Model calibration --------------------
print("\n=== Model Calibration ===")
# Fit an isotonic calibrator on the validation set (calibration_loader wraps X_val / y_val).
print("Calibrating model using validation set...")
calibrator, calibrated_predict_proba, cal_probs, cal_labels = calibrate_model(
    model, calibration_loader, device, method='isotonic'
)
# Map the raw test probabilities through the fitted calibrator.
calibrated_test_probs = calibrated_predict_proba(np.array(all_test_probs))
# Metrics after calibration, still using a fixed 0.5 decision threshold.
calibrated_predictions = (calibrated_test_probs > 0.5).astype(int)
calibrated_metrics = calculate_metrics(all_test_labels, calibrated_predictions, calibrated_test_probs)
print(f"\n=== Test Metrics (After Calibration) ===")
print(f"Accuracy: {calibrated_metrics['accuracy'] * 100:.2f}%")
print(f"Precision: {calibrated_metrics['precision']:.4f}")
print(f"Recall (Sensitivity): {calibrated_metrics['sensitivity']:.4f}")
print(f"Specificity: {calibrated_metrics['specificity']:.4f}")
print(f"F1-score: {calibrated_metrics['f1']:.4f}")
print(f"AUC: {calibrated_metrics['auc']:.4f}")
# Compare calibration quality (Brier score, ECE, MCE) before and after calibration.
print("\n=== Calibration Evaluation ===")
before_calibration = evaluate_calibration(all_test_labels, all_test_probs)
after_calibration = evaluate_calibration(all_test_labels, calibrated_test_probs)
print("Before Calibration:")
print(f" Brier Score: {before_calibration['brier_score']:.4f}")
print(f" Expected Calibration Error (ECE): {before_calibration['expected_calibration_error']:.4f}")
print(f" Max Calibration Error (MCE): {before_calibration['max_calibration_error']:.4f}")
print("\nAfter Calibration:")
print(f" Brier Score: {after_calibration['brier_score']:.4f}")
print(f" Expected Calibration Error (ECE): {after_calibration['expected_calibration_error']:.4f}")
print(f" Max Calibration Error (MCE): {after_calibration['max_calibration_error']:.4f}")
# -------------------- Evaluation plots --------------------
print("\n=== Generating Evaluation Plots ===")
# Array views of the test results; several helpers mask or compare element-wise.
test_labels_arr = np.asarray(all_test_labels).astype(int)
test_probs_arr = np.asarray(all_test_probs, dtype=float)
calibrated_test_probs = np.asarray(calibrated_test_probs, dtype=float)

# 1. Overview figure on the raw (uncalibrated) probabilities
plot_comprehensive_evaluation(test_labels_arr, test_probs_arr)
# 2. Confusion matrix after calibration (threshold 0.5)
plot_confusion_matrix(calibrated_metrics['conf_matrix'], 'Confusion Matrix (After Calibration)')

# Operating threshold: chosen on the VALIDATION set by Youden's J (sensitivity +
# specificity - 1), never on the test set, so the test metrics below are not optimistically
# biased. Predictions use `>= threshold`, which is how roc_curve defines each threshold's
# operating point. This matters for isotonic calibration, whose outputs take only a few
# distinct values: with `>` every sample sitting exactly on the threshold flips to negative.
val_labels_arr = np.asarray(cal_labels).astype(int)
calibrated_val_probs = calibrated_predict_proba(np.asarray(cal_probs))
val_fpr, val_tpr, val_thresholds = roc_curve(val_labels_arr, calibrated_val_probs)
best_threshold = val_thresholds[np.argmax(val_tpr - val_fpr)]
best_calibrated_predictions = (calibrated_test_probs >= best_threshold).astype(int)
best_calibrated_metrics = calculate_metrics(test_labels_arr, best_calibrated_predictions, calibrated_test_probs)

# 3. Test ROC curve with the validation-chosen operating point
fpr, tpr, thresholds = roc_curve(test_labels_arr, calibrated_test_probs)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(10, 8))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.plot(1 - best_calibrated_metrics['specificity'], best_calibrated_metrics['sensitivity'], 'ro',
         label=f'Best Threshold = {best_threshold:.4f}')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (1-Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.title('Receiver Operating Characteristic (ROC) Curve - After Calibration')
plt.legend(loc="lower right")
plt.grid(True)
plt.savefig('roc_curve_calibrated.png', dpi=300)
plt.show()

# 4. Precision-recall curve
pr_auc = plot_pr_curve(test_labels_arr, calibrated_test_probs)
# 5. Reliability diagram after calibration. The helper returns the pyplot module itself,
#    so the module-level `plt` name is used directly instead of being rebound.
improved_plot_calibration_curve(test_labels_arr, calibrated_test_probs,
                                n_bins=10, model_name='Calibrated Model')
plt.savefig('improved_calibration_curve.png', dpi=300, bbox_inches='tight')
plt.close()
# 6. Before/after calibration comparison
plot_calibration_comparison(test_labels_arr, test_probs_arr, calibrated_test_probs)
plt.savefig('calibration_comparison.png', dpi=300, bbox_inches='tight')
plt.close()
# 7. Decision curve analysis (after calibration)
plot_dca_curve(test_labels_arr, calibrated_test_probs)

# 8. Test performance at the validation-chosen threshold
print(f"\n=== Best Threshold Analysis (After Calibration) ===")
print(f"Best Threshold: {best_threshold:.4f}")
print(f"Sensitivity at Best Threshold: {best_calibrated_metrics['sensitivity']:.4f}")
print(f"Specificity at Best Threshold: {best_calibrated_metrics['specificity']:.4f}")
print(f"Accuracy at Best Threshold: {best_calibrated_metrics['accuracy'] * 100:.2f}%")
print(f"Precision at Best Threshold: {best_calibrated_metrics['precision']:.4f}")
print(f"F1-score at Best Threshold: {best_calibrated_metrics['f1']:.4f}")

# Persist the calibrator. joblib files are pickles: only load ones you created yourself.
import joblib
joblib.dump(calibrator, 'model_calibrator.pkl')
print("\nCalibrator saved to 'model_calibrator.pkl'")
print("\nAll evaluation plots have been saved!")
# NOTE(review): reported problem with this pipeline: test-set specificity is very low.