import glob
import os
import random

import monai
import nibabel as nib
import numpy as np
import torch
from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import (
    Compose, CropForegroundd, EnsureChannelFirstd, Lambdad, LoadImaged,
    MapTransform, Orientationd, RandAffined, RandFlipd, ResampleToMatchd,
    ResizeWithPadOrCropd, ScaleIntensityRanged, Spacingd, ToTensord,
)
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from sklearn.model_selection import StratifiedShuffleSplit
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR


def build_data_dict(paths, labels):
    """Pair each image with its mask and collect metadata; skip unusable cases."""
    data = []
    skipped = 0
    for path, label in zip(paths, labels):
        try:
            base_name = os.path.basename(path)
            file_id = base_name.replace(".nii.gz", "")
            mask_path = path.replace(".nii.gz", "mask.nii.gz")
            if not os.path.exists(mask_path):
                print(f"❌ Missing mask file: {mask_path}")
                skipped += 1
                continue
            img_nii = nib.load(path)
            mask_nii = nib.load(mask_path)
            if np.all(mask_nii.get_fdata() == 0):
                print(f"⚠️ Mask is all zeros, skipping: {mask_path}")
                skipped += 1
                continue
            data.append({
                "image": path,
                "mask": mask_path,
                "label": label,
                "id": file_id,
                "original_affine": np.array(img_nii.affine)[:4, :4].astype(np.float32),
                "original_shape": img_nii.shape,
                "mask_original_affine": np.array(mask_nii.affine)[:4, :4].astype(np.float32),
            })
        except Exception as e:
            print(f"❌ Failed to build entry: {path}, reason: {e}")
            skipped += 1
    print(f"✅ Build complete, valid samples: {len(data)}, skipped: {skipped}")
    return data
class SyncAffined(MapTransform):
    """Reorient image/mask to RAS, then resample the mask onto the image grid
    whenever their affines disagree beyond `atol`."""

    def __init__(self, keys, atol=1e-2, logger=None):
        super().__init__(keys)
        self.orientation = Orientationd(keys=keys, axcodes="RAS")
        self.resample = ResampleToMatchd(keys=["mask"], key_dst="image", mode="nearest")
        self.atol = atol
        self.logger = logger

    def __call__(self, data):
        try:
            data = self.orientation(data)
            a1 = data["image_meta_dict"]["affine"]
            a2 = data["mask_meta_dict"]["affine"]
            if isinstance(a1, torch.Tensor):
                a1 = a1.numpy()
            if isinstance(a2, torch.Tensor):
                a2 = a2.numpy()
            if not np.allclose(a1, a2, atol=self.atol):
                data = self.resample(data)
            return data
        except Exception as e:
            if self.logger:
                self.logger.error(f"Error during SyncAffined processing: {e}")
            raise
def get_transforms():
    deterministic_transforms = Compose([
        LoadImaged(keys=["image", "mask"], image_only=False, reader="ITKReader"),
        EnsureChannelFirstd(keys=["image", "mask"]),
        SyncAffined(keys=["image", "mask"], atol=1e-2),
        Spacingd(keys=["image", "mask"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
        CropForegroundd(keys=["image", "mask"], source_key="mask", margin=10, allow_smaller=True),
        ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=(64, 64, 64)),
        ScaleIntensityRanged(keys=["image"], a_min=20, a_max=80, b_min=0.0, b_max=1.0, clip=True),
        ToTensord(keys=["image", "mask"]),
    ])
    augmentation_transforms = Compose([
        # Flips all listed axes together when triggered
        RandFlipd(keys=["image", "mask"], prob=0.2, spatial_axis=[0, 1, 2]),
        RandAffined(
            keys=["image", "mask"],
            prob=0.3,
            # MONAI samples each range symmetrically: rotate_range=0.2 draws
            # angles from [-0.2, 0.2] rad per axis (extended here to all three
            # axes), and scale_range=0.2 draws scale factors from [0.8, 1.2]
            # (1.0 + sampled value). The original (0.8, 1.2) would have
            # allowed zero or negative scales.
            rotate_range=(0.2, 0.2, 0.2),
            scale_range=(0.2, 0.2, 0.2),
            shear_range=(0.1, 0.1, 0.1, 0.1, 0.1, 0.1),
            translate_range=(5, 5, 5),
            mode=("bilinear", "nearest"),
            padding_mode="border",
            spatial_size=(64, 64, 64),
        ),
        Lambdad(keys=["label"], func=lambda x: torch.tensor(x, dtype=torch.long).squeeze(0)),
    ])
    return deterministic_transforms, augmentation_transforms
deterministic_transforms, augmentation_transforms = get_transforms()

data_dir = "D:/monaisj/train"
class_dirs = sorted([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])
image_paths, labels = [], []
for class_name in class_dirs:
    class_path = os.path.join(data_dir, class_name)
    nii_files = glob.glob(os.path.join(class_path, "*.nii.gz"))
    for nii_file in nii_files:
        if "mask" not in os.path.basename(nii_file):  # skip mask volumes
            image_paths.append(nii_file)
            labels.append(int(class_name))
# Stratified train/val split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_indices, val_indices = next(sss.split(image_paths, labels))
train_paths = [image_paths[i] for i in train_indices]
val_paths = [image_paths[i] for i in val_indices]
train_labels = [labels[i] for i in train_indices]
val_labels = [labels[i] for i in val_indices]
train_files = build_data_dict(train_paths, train_labels)
val_files = build_data_dict(val_paths, val_labels)
# -------------------- Dataset loading --------------------
# Deterministic transforms are cached once; random augmentations are applied
# on the fly by wrapping the cached dataset in a plain Dataset.
train_ds = CacheDataset(data=train_files, transform=deterministic_transforms, cache_rate=0.8)
train_ds = Dataset(data=train_ds, transform=augmentation_transforms)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=0)
val_ds = CacheDataset(data=val_files, transform=deterministic_transforms, cache_rate=1.0)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
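# A quick sanity check of the deterministic pipeline (a minimal sketch, not part
# of the original script): push one sample through it and confirm the expected
# (1, 64, 64, 64) shape and [0, 1] intensity range.
sample = deterministic_transforms(train_files[0])
print(sample["image"].shape, sample["mask"].shape)                 # torch.Size([1, 64, 64, 64])
print(float(sample["image"].min()), float(sample["image"].max()))  # within [0.0, 1.0]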
model = monai.networks.nets.resnet18(pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=2).to(device)
# NOTE: the filename suggests ResNet-50 weights while the model is ResNet-18;
# with strict=False, keys that do not match are silently ignored.
weights_path = "D:/MedicalNet/pretrain/resnet_50_epoch_110_batch_0.pth"
if os.path.exists(weights_path):
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    if "state_dict" in state_dict:  # MedicalNet checkpoints typically wrap weights under "state_dict"
        state_dict = state_dict["state_dict"]
    new_state_dict = {}
    for k, v in state_dict.items():
        name = k.replace("model.", "").replace("module.", "")
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict, strict=False)
    print(f"Loaded weights successfully: {weights_path}")
else:
    print(f"Weights file does not exist: {weights_path}")
def init_weights(m):
    """Kaiming-initialize the freshly replaced classification head."""
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight)
        torch.nn.init.constant_(m.bias, 0.0)

model.fc.apply(init_weights)
# Inverse-frequency class weights to counter class imbalance
counts = np.bincount(train_labels)
class_weights = torch.tensor(
    [len(train_labels) / (2.0 * counts[0]), len(train_labels) / (2.0 * counts[1])],
    dtype=torch.float32,
).to(device)
loss_fn = CrossEntropyLoss(weight=class_weights)
def compute_metrics(labels, preds, probs):
    cm = confusion_matrix(labels, preds)
    tn, fp, fn, tp = cm.ravel()  # assumes binary classification
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    auc = roc_auc_score(labels, probs)  # raises if only one class is present
    return sensitivity, specificity, accuracy, auc
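# Toy check for compute_metrics (illustrative values, not project data); with
# both classes present, roc_auc_score is well defined.
_sens, _spec, _acc, _auc = compute_metrics(
    labels=[0, 0, 1, 1], preds=[0, 1, 1, 1], probs=[0.2, 0.6, 0.7, 0.9]
)
print(f"sanity: sens={_sens:.2f} spec={_spec:.2f} acc={_acc:.2f} auc={_auc:.2f}")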
def freeze_layers(model, freeze_patterns=None, unfreeze_patterns=None):
    """Freeze all parameters, then unfreeze those matching any pattern.
    The classification head ("fc") is always left trainable."""
    for name, param in model.named_parameters():
        param.requires_grad = False
        if unfreeze_patterns:
            for pattern in unfreeze_patterns:
                if pattern in name:
                    param.requires_grad = True
                    break
        if "fc" in name:
            param.requires_grad = True
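# To verify a freezing stage did what was intended, counting trainable
# parameters is a cheap check (hypothetical helper, not part of the original
# script).
def count_trainable_params(model):
    """Return (trainable, total) parameter counts; useful after freeze_layers."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# e.g. after freeze_layers(model, unfreeze_patterns=["fc"]):
#   trainable, total = count_trainable_params(model)
#   print(f"trainable: {trainable:,} / {total:,}")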
def train_model(model, stage=1, epochs=50, init_lr=1e-3, eta_min=1e-5, data_dir="D:/monaisj/"):
    # Stage configuration
    if stage == 1:
        # Stage 1: train the classification head only
        freeze_layers(model, unfreeze_patterns=["fc"])
        print("🔒 Stage 1: backbone frozen, training the classification head only")
        lr = init_lr  # higher learning rate
    elif stage == 2:
        # Stage 2: unfreeze the upper layers
        freeze_layers(model, unfreeze_patterns=["layer3", "layer4", "fc"])
        print("🔓 Stage 2: upper layers unfrozen (layer3/layer4)")
        lr = init_lr * 0.1  # lower learning rate
    elif stage == 3:
        # Stage 3: unfreeze everything
        for param in model.parameters():
            param.requires_grad = True
        print("🔥 Stage 3: entire network unfrozen")
        lr = init_lr * 0.01  # even lower learning rate
    else:
        raise ValueError(f"Unknown stage: {stage}")
    # Optimizer over trainable parameters only
    params_to_optimize = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(params_to_optimize, lr=lr, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=eta_min)
    scaler = GradScaler()  # create once, not per epoch
    best_val_auc = 0.0
    # Metric history
    history = {
        'train_loss': [], 'train_acc': [], 'train_auc': [],
        'train_sensitivity': [], 'train_specificity': [],
        'val_loss': [], 'val_acc': [], 'val_auc': [],
        'val_sensitivity': [], 'val_specificity': [],
        'train_true_labels': None, 'train_probs': [],
        'val_true_labels': None, 'val_probs': []
    }
    for epoch in range(epochs):
        # ================== Training ==================
        model.train()
        epoch_loss = 0.0
        train_preds, train_labels, train_probs = [], [], []
        for batch in train_loader:
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].long().to(device)
            optimizer.zero_grad(set_to_none=True)
            # Mixed precision: the original scaled the loss but never entered
            # an autocast region, so AMP was effectively inactive
            with autocast():
                outputs = model(images)
                loss = loss_fn(outputs, labels)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item() * images.size(0)
            # Collect training predictions
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            probs = torch.softmax(outputs.float(), dim=1)[:, 1].detach().cpu().numpy()
            train_preds.extend(preds)
            train_labels.extend(labels.cpu().numpy())
            train_probs.extend(probs)
        # Record training labels (first epoch only)
        if epoch == 0:
            history['train_true_labels'] = train_labels
        history['train_probs'].append(train_probs)
        # Training metrics; divide by sample count to match the per-sample
        # accumulation above (the original divided by the batch count)
        train_loss = epoch_loss / len(train_loader.dataset)
        train_acc = accuracy_score(train_labels, train_preds)
        train_sensitivity, train_specificity, _, train_auc = compute_metrics(
            train_labels, train_preds, train_probs
        )
        # ================== Validation ==================
        model.eval()
        val_loss = 0.0
        val_preds, val_labels, val_probs = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device, non_blocking=True)
                labels = batch["label"].long().to(device)
                outputs = model(images)
                loss = loss_fn(outputs, labels)
                val_loss += loss.item() * images.size(0)
                # Collect validation predictions
                probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
                preds = torch.argmax(outputs, dim=1).cpu().numpy()
                val_preds.extend(preds)
                val_labels.extend(labels.cpu().numpy())
                val_probs.extend(probs)
        # Record validation labels (first epoch only)
        if epoch == 0:
            history['val_true_labels'] = val_labels
        history['val_probs'].append(val_probs)
        # Validation metrics, again normalized per sample
        val_loss = val_loss / len(val_loader.dataset)
        val_acc = accuracy_score(val_labels, val_preds)
        val_sensitivity, val_specificity, _, val_auc = compute_metrics(
            val_labels, val_preds, val_probs
        )
        # ================== LR update and checkpointing ==================
        scheduler.step()
        best_val_auc = max(best_val_auc, val_auc)
        # Save a checkpoint every epoch; the stage runs below load specific epochs by name
        save_path = os.path.join(data_dir, f"best_model_stage{stage}_epoch{epoch+1}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Saved model to {save_path}")
        # ================== Record history ==================
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['train_auc'].append(train_auc)
        history['train_sensitivity'].append(train_sensitivity)
        history['train_specificity'].append(train_specificity)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auc'].append(val_auc)
        history['val_sensitivity'].append(val_sensitivity)
        history['val_specificity'].append(val_specificity)
        # Progress report
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{epochs} [Stage {stage}]")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | AUC: {train_auc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | AUC: {val_auc:.4f}")
        print(f"Current learning rate: {current_lr:.2e}")
        print("-" * 50)
    print(f"Best val AUC in stage {stage}: {best_val_auc:.4f}")
    return history
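# Since train_model returns the full metric history, a small matplotlib sketch
# can visualize the loss/AUC curves per stage (hypothetical helper, assuming
# matplotlib is available; not part of the original script).
import matplotlib.pyplot as plt

def plot_history(history, stage, out_path=None):
    """Plot train/val loss and AUC curves from a train_model history dict."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    epochs = range(1, len(history['train_loss']) + 1)
    ax1.plot(epochs, history['train_loss'], label="train")
    ax1.plot(epochs, history['val_loss'], label="val")
    ax1.set_xlabel("epoch"); ax1.set_ylabel("loss"); ax1.legend()
    ax2.plot(epochs, history['train_auc'], label="train")
    ax2.plot(epochs, history['val_auc'], label="val")
    ax2.set_xlabel("epoch"); ax2.set_ylabel("AUC"); ax2.legend()
    fig.suptitle(f"Stage {stage}")
    fig.tight_layout()
    if out_path:
        fig.savefig(out_path, dpi=150)
    plt.show()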
SEED = 123456
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seed torch for reproducibility
# Stage 1: freeze the backbone, train the classification head only
print("\nStarting stage 1 training (backbone frozen)")
# Note: eta_min == init_lr here, so the cosine schedule keeps the LR constant in stage 1
stage1_history = train_model(model, stage=1, epochs=40, init_lr=1e-2, eta_min=1e-2)

print("\nStarting stage 2 training (upper layers unfrozen)")
# Load the best stage-1 model
# best_stage1_epoch = np.argmax([auc for auc in stage1_history['val_auc']]) + 1
model.load_state_dict(torch.load("D:/monaisj/best_model_stage1_epoch19.pth", map_location=device, weights_only=True))
stage2_history = train_model(model, stage=2, epochs=30, init_lr=1e-3, eta_min=1e-4)

# Load the best stage-2 model
model.load_state_dict(torch.load("D:/monaisj/best_model_stage2_epoch20.pth", map_location=device, weights_only=True))
stage3_history = train_model(model, stage=3, epochs=30, init_lr=1e-4, eta_min=1e-5)

# Load the best model and generate Grad-CAM heatmaps: panel (a) the preprocessed,
# mask-cropped model input, (b) the mask, (c) the CAM, (d) the CAM overlaid on the input.
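# A minimal Grad-CAM sketch, not a definitive implementation. Assumptions:
# MONAI's ResNet exposes `.layer4` (it mirrors torchvision naming), matplotlib
# is available, and `best_model_path` is a hypothetical placeholder that must
# be pointed at your actual best checkpoint.
import matplotlib.pyplot as plt
import torch.nn.functional as F

best_model_path = "D:/monaisj/best_model_stage3_epoch1.pth"  # hypothetical: edit to your best checkpoint
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
model.eval()

activations, gradients = {}, {}

def _fwd_hook(module, inp, out):
    activations["value"] = out.detach()

def _bwd_hook(module, grad_in, grad_out):
    gradients["value"] = grad_out[0].detach()

target_layer = model.layer4  # last conv stage
h1 = target_layer.register_forward_hook(_fwd_hook)
h2 = target_layer.register_full_backward_hook(_bwd_hook)

sample = deterministic_transforms(val_files[0])   # one preprocessed case
image = sample["image"].unsqueeze(0).to(device)   # (1, 1, 64, 64, 64)
image.requires_grad_(True)                        # ensure gradients flow even if the backbone were frozen
mask = sample["mask"][0].cpu().numpy()            # (64, 64, 64)

logits = model(image)
cls = logits.argmax(dim=1).item()
model.zero_grad()
logits[0, cls].backward()

h1.remove(); h2.remove()

# Grad-CAM: channel weights = spatially averaged gradients; CAM = ReLU(weighted sum)
w = gradients["value"].mean(dim=(2, 3, 4), keepdim=True)           # (1, C, 1, 1, 1)
cam = F.relu((w * activations["value"]).sum(dim=1, keepdim=True))  # (1, 1, d, h, w)
cam = F.interpolate(cam, size=image.shape[2:], mode="trilinear", align_corners=False)
cam = cam[0, 0].cpu().numpy()
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

img_np = image[0, 0].detach().cpu().numpy()
z = int(np.argmax(mask.sum(axis=(0, 1))))  # axial slice with the most mask voxels

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(img_np[:, :, z], cmap="gray"); axes[0].set_title("(a) input")
axes[1].imshow(mask[:, :, z], cmap="gray"); axes[1].set_title("(b) mask")
axes[2].imshow(cam[:, :, z], cmap="jet"); axes[2].set_title("(c) Grad-CAM")
axes[3].imshow(img_np[:, :, z], cmap="gray")
axes[3].imshow(cam[:, :, z], cmap="jet", alpha=0.4); axes[3].set_title("(d) overlay")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.savefig("gradcam_panels.png", dpi=150)
plt.show()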