# 文件:train_swinunetr_earlyfusion.py
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from monai.transforms import (
Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged,
RandCropByPosNegLabeld, RandFlipd, RandRotate90d, EnsureTyped,
Resized, RandZoomd, RandGaussianNoised, CenterSpatialCropd,
Activations, AsDiscrete
)
from monai.data import PersistentDataset, list_data_collate, decollate_batch
from monai.networks.nets import SwinUNETR
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss
# Matplotlib font settings: DejaVu Sans, and an ASCII hyphen for minus signs.
mpl.rcParams['font.sans-serif'] = ['DejaVu Sans']
mpl.rcParams['axes.unicode_minus'] = False
# ========================== Configuration ==========================
root_dir = "datasets/LiTS/processed"
images = sorted(glob(os.path.join(root_dir, "images", "*.nii.gz")))
labels = sorted(glob(os.path.join(root_dir, "labels", "*.nii.gz")))
# Images and labels are paired purely by sorted order via zip(), which silently
# truncates when the counts differ (and then every later pair is wrong). Fail fast.
if not images:
    raise FileNotFoundError(f"No *.nii.gz images found in {os.path.join(root_dir, 'images')}")
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} labels under {root_dir}; "
        "images and labels are paired by sorted filename and must match one-to-one."
    )
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_epochs = 200
batch_size = 1
num_classes = 3
learning_rate = 1e-4
clip_dim = 512
use_amp = False
base_size = (128, 128, 64)
# SwinUNETR expects every spatial dim to be divisible by 2**5 = 32: round each
# size down to a multiple of 32, with 32 as the minimum.
crop_size = tuple(max(32, s // 32 * 32) for s in (64, 64, 32))
resized_size = tuple(max(32, s // 32 * 32) for s in base_size)
print(f"使用尺寸: resized={resized_size}, crop={crop_size}")
# ===================== Data preprocessing =====================
def _preprocessing_transforms():
    """Return fresh instances of the deterministic preprocessing steps.

    Load -> channel-first -> RAS orientation -> resample to 1.5x1.5x2.0 ->
    clip/scale intensities from [-200, 200] to [0, 1] -> resize to ``resized_size``.
    Shared by the train and val pipelines so the two cannot drift apart.
    """
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    ]


train_transforms = Compose(_preprocessing_transforms() + [
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=crop_size, pos=1.0, neg=1.0, num_samples=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
    RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.1),
    EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(_preprocessing_transforms() + [
    CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),
    EnsureTyped(keys=["image", "label"]),
])
# PersistentDataset keys its on-disk cache on the input dicts (file paths) only,
# not on the transform parameters, so after changing resized_size/crop_size it
# would silently keep serving stale tensors. Tag the cache dirs with the sizes.
_cache_tag = f"r{'x'.join(map(str, resized_size))}_c{'x'.join(map(str, crop_size))}"
train_ds = PersistentDataset(train_files, transform=train_transforms, cache_dir=f"./cache/train_{_cache_tag}")
val_ds = PersistentDataset(val_files, transform=val_transforms, cache_dir=f"./cache/val_{_cache_tag}")
# Pinned memory only helps host->GPU copies; skip it on CPU-only machines.
_pin = torch.cuda.is_available()
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=list_data_collate, num_workers=4, pin_memory=_pin)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=list_data_collate, num_workers=2, pin_memory=_pin)
# =============== Load text features ===============
clip_feats = np.load("./clip_text_features.npy")  # expected shape: (N, clip_dim)
# get_text_features() samples rows with torch.randint(0, N), and the fusion
# module's key/value projections take clip_dim inputs, so validate up front
# instead of failing deep inside the first forward pass.
if clip_feats.ndim != 2 or clip_feats.shape[0] == 0 or clip_feats.shape[1] != clip_dim:
    raise ValueError(
        f"./clip_text_features.npy must have shape (N, {clip_dim}) with N >= 1, "
        f"got {clip_feats.shape}"
    )
clip_feats = torch.from_numpy(clip_feats).float()
def get_text_features(bs):
    """Draw ``bs`` rows of ``clip_feats`` uniformly at random, with replacement.

    Args:
        bs: number of rows to sample (normally the batch size).

    Returns:
        The sampled rows stacked along dim 0, moved to ``device``.
    """
    picks = torch.randint(low=0, high=clip_feats.shape[0], size=(bs,))
    sampled = clip_feats[picks]
    return sampled.to(device)
# =============== Fusion module ===============
class EarlyFusionCrossAttention(nn.Module):
    """Residual cross-attention from a 3-D image feature map to text tokens.

    Every voxel of the image feature map (query) attends over the text tokens
    (keys/values). The attended values are projected back to ``img_dim`` and
    added to the input, so the output has the same shape as the input.

    NOTE: with a single text vector of shape (B, text_dim) there is only one
    key, so the softmax is identically 1. Every voxel then receives the same
    ``out(value(text))`` vector, and ``query``/``key`` get no gradient. Pass
    several tokens, shape (B, T, text_dim), for the attention weights to
    actually depend on the image.

    Args:
        img_dim: channel count C of the image feature map.
        text_dim: dimensionality of each text token.
    """

    def __init__(self, img_dim=192, text_dim=512):
        super().__init__()
        self.query = nn.Linear(img_dim, img_dim)
        self.key = nn.Linear(text_dim, img_dim)
        self.value = nn.Linear(text_dim, img_dim)
        self.out = nn.Linear(img_dim, img_dim)

    def forward(self, img_feat, text_feat):
        """Fuse ``text_feat`` into ``img_feat``.

        Args:
            img_feat: (B, C, D, H, W) feature map with C == img_dim.
            text_feat: (B, text_dim) for one token per sample, or (B, T, text_dim).

        Returns:
            Tensor with the same shape as ``img_feat``.
        """
        B, C, D, H, W = img_feat.shape
        if text_feat.dim() == 2:
            text_feat = text_feat.unsqueeze(1)  # (B, 1, text_dim)
        # flatten/reshape (not view) so non-contiguous, e.g. permuted, inputs work too.
        img_flat = img_feat.flatten(2).transpose(1, 2)  # (B, N, C) with N = D*H*W
        q = self.query(img_flat)   # (B, N, C)
        k = self.key(text_feat)    # (B, T, C)
        v = self.value(text_feat)  # (B, T, C)
        attn = torch.softmax(q @ k.transpose(-2, -1) / C ** 0.5, dim=-1)  # (B, N, T)
        fused = self.out(attn @ v)  # (B, N, C)
        return img_feat + fused.transpose(1, 2).reshape(B, C, D, H, W)
# =============== Main model: SwinUNETR with text fusion at the bottleneck ===============
class SwinUNETRWithEarlyFusion(SwinUNETR):
    """SwinUNETR that fuses a text embedding into its deepest encoder feature map.

    MONAI's ``SwinUNETR`` defines no ``forward_up`` method (calling it raised
    ``AttributeError``). ``forward`` below therefore re-implements the parent's
    encoder/decoder wiring and applies ``self.fusion`` to the last Swin stage
    output before it enters the bottleneck block ``encoder10``.

    Args:
        img_size: spatial input size forwarded to ``SwinUNETR``.
        in_channels: number of input image channels.
        out_channels: number of output segmentation channels.
        feature_size: SwinUNETR base feature size.
        text_feat_dim: dimensionality of the text embedding passed to ``forward``.
    """

    def __init__(self, img_size, in_channels, out_channels, feature_size=12, text_feat_dim=512):
        super().__init__(
            img_size=img_size,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_size=feature_size,
        )
        # The deepest Swin stage has feature_size * 16 channels, which is also
        # the input width of SwinUNETR's bottleneck block encoder10.
        self.fusion = EarlyFusionCrossAttention(
            img_dim=feature_size * 16,
            text_dim=text_feat_dim,
        )

    def forward(self, x, text_feat=None):
        """Segment ``x``, optionally conditioned on ``text_feat``.

        Args:
            x: image batch (B, in_channels, D, H, W); SwinUNETR expects every
                spatial dim to be divisible by 32.
            text_feat: optional text embeddings for ``self.fusion``, e.g.
                (B, text_feat_dim). When None this is a plain SwinUNETR forward.

        Returns:
            Segmentation logits with ``out_channels`` channels.
        """
        # Same input-size validation as SwinUNETR.forward (helper only exists in newer MONAI).
        check_size = getattr(self, "_check_input_size", None)
        if check_size is not None and not torch.jit.is_scripting():
            check_size(x.shape[2:])

        # Encoder: the Swin transformer returns five hierarchical feature maps.
        # Honour the parent's `normalize` flag, as SwinUNETR.forward does.
        hidden_states_out = list(self.swinViT(x, getattr(self, "normalize", True)))
        if text_feat is not None:
            hidden_states_out[4] = self.fusion(hidden_states_out[4], text_feat)

        # Decoder: identical wiring to SwinUNETR.forward.
        enc0 = self.encoder1(x)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        dec4 = self.encoder10(hidden_states_out[4])
        dec3 = self.decoder5(dec4, hidden_states_out[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0, enc0)
        return self.out(out)
# =============== Model, loss, optimizer and metrics ===============
model = SwinUNETRWithEarlyFusion(
    img_size=crop_size,
    in_channels=1,
    out_channels=num_classes,
    feature_size=12,
    text_feat_dim=clip_dim,
)
model = model.to(device)

# One loss weight per output channel (num_classes == 3), index 0 first.
class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
loss_fn = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    include_background=True,
    weight=class_weights,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# Stepped with the validation Dice in the training loop, hence mode="max".
scheduler = ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=5, min_lr=1e-6, verbose=True
)

post_pred = Compose([
    Activations(softmax=True),
    AsDiscrete(argmax=True, to_onehot=num_classes),
])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
# NOTE(review): background is included in the Dice mean; it usually dominates
# the score, so consider include_background=False for model selection.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
    num_classes=num_classes,
)
scaler = GradScaler(enabled=use_amp)

best_dice = -1
os.makedirs("earlyfusion_checkpoints", exist_ok=True)
# =============== Training loop ===============
# Fixed seed for the text features drawn during validation. The validation Dice
# drives the LR scheduler and checkpoint selection, so it must not change from
# epoch to epoch just because a different random text vector was sampled.
val_text_seed = 0
for epoch in range(1, max_epochs + 1):
    print(f"\nEpoch {epoch}/{max_epochs}")
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_loader, desc="Train"):
        # `inputs`/`targets` avoid shadowing the module-level `images`/`labels` file lists.
        inputs = batch["image"].to(device)
        targets = batch["label"].to(device)
        text_feat = get_text_features(inputs.shape[0])
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs, text_feat)
            loss = loss_fn(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    print(f"Train Loss: {epoch_loss / max(len(train_loader), 1):.4f}")

    # Validation
    model.eval()
    val_gen = torch.Generator().manual_seed(val_text_seed)  # same draws every epoch
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Val"):
            inputs = batch["image"].to(device)
            targets = batch["label"].to(device)
            idx = torch.randint(0, len(clip_feats), (inputs.shape[0],), generator=val_gen)
            text_feat = clip_feats[idx].to(device)
            with autocast(enabled=use_amp):
                outputs = model(inputs, text_feat)
            outputs_convert = [post_pred(o) for o in decollate_batch(outputs)]
            labels_convert = [post_label(lbl) for lbl in decollate_batch(targets)]
            # Accumulate per-sample scores; reduce once per epoch below (with
            # val batch_size=1 this equals the mean of per-sample Dice scores).
            dice_metric(y_pred=outputs_convert, y=labels_convert)
    avg_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    print(f"Val Dice: {avg_dice:.4f}")
    scheduler.step(avg_dice)
    # Save the best model so far
    if avg_dice > best_dice:
        best_dice = avg_dice
        torch.save(model.state_dict(), f"earlyfusion_checkpoints/best_model_epoch{epoch}_dice{avg_dice:.4f}.pth")
        print(f"✅ Saved best model at epoch {epoch} with Dice {avg_dice:.4f}")
# Traceback reported for the original `return self.forward_up(enc_out)` call
# (MONAI's SwinUNETR defines no `forward_up` method):
#   File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
#     raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
#   AttributeError: 'SwinUNETRWithEarlyFusion' object has no attribute 'forward_up'