import os
import gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from monai.transforms import (
Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged,
RandCropByPosNegLabeld, RandFlipd, RandRotate90d, EnsureTyped,
Resized, RandZoomd, RandGaussianNoised, CenterSpatialCropd,
Activations, AsDiscrete, RandCoarseDropoutd, RandBiasFieldd
)
from monai.data import PersistentDataset, list_data_collate, decollate_batch
from monai.networks.nets import SwinUNETR
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss, FocalLoss
# ================ Memory optimization settings ================
# Reduce CUDA memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# cuDNN benchmarking speeds up training but uses extra memory; set to False if memory is tight
torch.backends.cudnn.benchmark = True
# ========================== Configuration ==========================
root_dir = "datasets/LiTS/processed"
images = sorted(glob(os.path.join(root_dir, "images", "*.nii.gz")))
labels = sorted(glob(os.path.join(root_dir, "labels", "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
# Memory optimization: keep the validation split relatively small
train_files, val_files = train_test_split(data, test_size=0.15, random_state=42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training hyperparameters
max_epochs = 200
batch_size = 1  # keep batch_size at 1 for the 3D model
num_classes = 3
learning_rate = 1e-4
clip_dim = 512
use_amp = True  # mixed precision to reduce memory use
accumulation_steps = 4  # gradient accumulation steps to simulate a larger batch size
# Volume sizes, reduced to save memory
base_size = (96, 96, 48)  # originally (128, 128, 64); kept for reference, not used below
crop_size = (64, 64, 32)  # patch size fed to the network
print(f"Using crop size: crop={crop_size}")
# ===================== Memory-friendly data preprocessing =====================
train_transforms = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    # Crop directly without a Resized step to reduce memory use
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=crop_size,
pos=1.0,
neg=1.0,
num_samples=1
),
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.1, mode=("trilinear", "nearest")),  # narrow zoom range
    RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.05),  # low-amplitude noise
    # Additional memory-friendly augmentations
    RandCoarseDropoutd(
        keys=["image"],
        holes=5,  # minimum number of dropout holes
        spatial_size=(10, 10, 5),  # small hole size
        max_holes=8,
        prob=0.2,
        fill_value=0
    ),
    RandBiasFieldd(
        keys=["image"],
        coeff_range=(0.05, 0.15),  # mild bias-field strength
        prob=0.1
    ),
EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose([
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Orientationd(keys=["image", "label"], axcodes="RAS"),
Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),  # crop directly to the patch size
EnsureTyped(keys=["image", "label"]),
])
# Use PersistentDataset, which caches deterministic transform results on disk.
# It does not take a cache_rate argument; see the note at the end of the script.
os.makedirs("./cache/train", exist_ok=True)
os.makedirs("./cache/val", exist_ok=True)
train_ds = PersistentDataset(
    train_files,
    transform=train_transforms,
    cache_dir="./cache/train"
)
val_ds = PersistentDataset(
    val_files,
    transform=val_transforms,
    cache_dir="./cache/val"
)
# Data loaders: use few workers to save host memory
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=list_data_collate,
    num_workers=2,  # small worker count
    pin_memory=True
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    collate_fn=list_data_collate,
    num_workers=1,  # small worker count
    pin_memory=True
)
# =============== Load CLIP text features ===============
# Memory optimization: memory-map the feature file instead of loading it all into RAM
# (the original also copied the full array to the GPU, which was unused and defeated the mapping)
clip_feats = np.load("./clip_text_features.npy", mmap_mode='r')
def get_text_features(bs):
    """Memory-friendly text-feature lookup: sample bs rows from the memory-mapped array."""
    idx = np.random.randint(0, clip_feats.shape[0], size=bs)
    # Fancy indexing on the memmap materializes only the selected rows
    return torch.from_numpy(np.asarray(clip_feats[idx])).float().to(device)
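# Sanity check (assumes the exported file holds one feature vector per text prompt, shaped
# (num_prompts, clip_dim)); fail fast if it does not match the fusion modules' text_dim
assert clip_feats.ndim == 2 and clip_feats.shape[1] == clip_dim, \
    f"expected CLIP text features of shape (N, {clip_dim}), got {clip_feats.shape}"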
# =============== Fusion module ===============
class MemoryEfficientCrossAttention(nn.Module):
    """Memory-efficient cross-attention: image voxels attend to a single text token."""
    def __init__(self, img_dim=192, text_dim=512, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = img_dim // num_heads
        # Keep the linear layers small
        self.qkv = nn.Linear(img_dim, img_dim * 3, bias=False)
        self.text_proj = nn.Linear(text_dim, img_dim)
        self.out = nn.Linear(img_dim, img_dim)
    def forward(self, img_feat, text_feat):
        B, C, D, H, W = img_feat.shape
        N = D * H * W
        img_flat = img_feat.view(B, C, N).permute(0, 2, 1)  # (B, N, C)
        # Multi-head projections from the image features
        qkv = self.qkv(img_flat).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, v = qkv[0], qkv[2]  # (B, num_heads, N, head_dim); the image-derived keys are unused
        # Project the text feature to one key token per head
        text_k = self.text_proj(text_feat).view(B, 1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, num_heads, 1, head_dim)
        # Scaled dot-product score of every voxel query against the single text key.
        # With only one text token a softmax over keys would be identically 1, so a sigmoid gate
        # applied to the image values is used instead (the original matmul of the (N, 1) weights
        # with the image values was shape-incompatible and raised at runtime).
        score = torch.matmul(q, text_k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, num_heads, N, 1)
        gate = torch.sigmoid(score)
        context = gate * v  # (B, num_heads, N, head_dim)
        context = context.transpose(1, 2).contiguous().view(B, N, C)  # (B, N, C)
        # Output projection and residual connection
        out = self.out(context).permute(0, 2, 1).view(B, C, D, H, W)
        return img_feat + out
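# Quick shape check for the fusion module (illustrative sketch only, not executed during
# training; the feature-map size below is arbitrary):
# _fusion = MemoryEfficientCrossAttention(img_dim=48, text_dim=512)
# _out = _fusion(torch.randn(1, 48, 8, 8, 4), torch.randn(1, 512))
# assert _out.shape == (1, 48, 8, 8, 4)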
# =============== Main model ===============
class EfficientSwinUNETR(SwinUNETR):
    """Memory-optimized SwinUNETR variant with multi-scale text fusion and deep supervision."""
    def __init__(self, img_size, in_channels, out_channels, feature_size=12, text_feat_dim=512):
        super().__init__(
            img_size=img_size,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_size=feature_size,
            depths=(2, 2, 2, 2),        # shallow transformer stages to limit memory
            num_heads=(3, 6, 12, 24),   # per-stage head counts
            use_checkpoint=True         # gradient checkpointing to reduce peak activation memory
        )
        # Multi-scale fusion modules on the deeper encoder features
        self.fusion_low = MemoryEfficientCrossAttention(img_dim=feature_size*4, text_dim=text_feat_dim)
        self.fusion_mid = MemoryEfficientCrossAttention(img_dim=feature_size*8, text_dim=text_feat_dim)
        self.fusion_high = MemoryEfficientCrossAttention(img_dim=feature_size*16, text_dim=text_feat_dim)
        # Deep-supervision heads
        self.aux_out1 = nn.Conv3d(feature_size*8, out_channels, kernel_size=1)
        self.aux_out2 = nn.Conv3d(feature_size*4, out_channels, kernel_size=1)
    def forward(self, x, text_feat=None):
        # Swin transformer encoder hidden states: [x0, x1, x2, x3, x4]
        hidden = self.swinViT(x, self.normalize)
        # Multi-scale text fusion
        if text_feat is not None:
            if text_feat.dim() == 1:
                text_feat = text_feat.unsqueeze(0)
            hidden[2] = self.fusion_low(hidden[2], text_feat)    # low-level features
            hidden[3] = self.fusion_mid(hidden[3], text_feat)    # mid-level features
            hidden[4] = self.fusion_high(hidden[4], text_feat)   # high-level features
        # Decoder path, mirroring monai.networks.nets.SwinUNETR.forward, so the fused features
        # actually reach the decoder (the original called super().forward(x), which recomputes
        # the encoder and ignores the fusion above)
        enc0 = self.encoder1(x)
        enc1 = self.encoder2(hidden[0])
        enc2 = self.encoder3(hidden[1])
        enc3 = self.encoder4(hidden[2])
        dec4 = self.encoder10(hidden[4])
        dec3 = self.decoder5(dec4, hidden[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        dec_out = self.out(self.decoder1(dec0, enc0))
        # Deep-supervision outputs from the fused encoder features, upsampled to the input size
        aux1 = F.interpolate(self.aux_out1(hidden[3]), size=x.shape[2:], mode='trilinear', align_corners=False)
        aux2 = F.interpolate(self.aux_out2(hidden[2]), size=x.shape[2:], mode='trilinear', align_corners=False)
        return dec_out, aux1, aux2
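# Optional smoke test of the full model (a minimal sketch; uncomment to check output shapes
# on a small dummy volume before starting a long training run):
# _m = EfficientSwinUNETR(img_size=crop_size, in_channels=1, out_channels=num_classes, feature_size=12)
# _y, _a1, _a2 = _m(torch.randn(1, 1, *crop_size), torch.randn(1, clip_dim))
# print(_y.shape, _a1.shape, _a2.shape)  # each expected to be (1, num_classes, *crop_size)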
# =============== Model, loss, and optimization setup ===============
# Initialize the model
model = EfficientSwinUNETR(
    img_size=crop_size,
    in_channels=1,
    out_channels=num_classes,
    feature_size=12,  # must be divisible by 3 for the (3, 6, 12, 24) head counts; the original value of 10 breaks the attention reshape
    text_feat_dim=clip_dim
).to(device)
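# Report the parameter count; a rough guide when tuning feature_size against GPU memory
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params / 1e6:.2f}M")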
# Memory note: gradient checkpointing is enabled via use_checkpoint=True in the model
# constructor above; MONAI modules do not provide a set_grad_checkpointing() method
# (that is a timm-style API), so the original module loop had no effect.
# Combined loss function
class CombinedLoss(nn.Module):
    """Combines Dice, cross-entropy, and focal losses with deep-supervision weighting."""
    def __init__(self, weights=(0.7, 0.2, 0.1)):
        super().__init__()
        self.dice_ce = DiceCELoss(
            to_onehot_y=True,
            softmax=True,
            include_background=True,
            weight=torch.tensor([0.2, 0.3, 0.5]).to(device)
        )
        self.focal = FocalLoss(to_onehot_y=True, gamma=2.0)
        self.weights = weights
    def forward(self, outputs, target):
        main_out, aux1, aux2 = outputs
        # Loss on the main output
        loss_main = self.dice_ce(main_out, target) + self.focal(main_out, target)
        # Losses on the auxiliary (deep-supervision) outputs
        loss_aux1 = self.dice_ce(aux1, target) + self.focal(aux1, target)
        loss_aux2 = self.dice_ce(aux2, target) + self.focal(aux2, target)
        # Weighted combination
        total_loss = (
            self.weights[0] * loss_main +
            self.weights[1] * loss_aux1 +
            self.weights[2] * loss_aux2
        )
        return total_loss
loss_fn = CombinedLoss().to(device)
# Optimizer and learning-rate schedule
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
    betas=(0.9, 0.98)  # beta2 lowered from the default 0.999
)
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=20,      # restart every 20 epochs
    T_mult=1,    # keep the restart period constant
    eta_min=1e-6
)
# Evaluation: post-processing transforms and Dice metric
post_pred = Compose([
Activations(softmax=True),
AsDiscrete(argmax=True, to_onehot=num_classes)
])
post_label = Compose([
AsDiscrete(to_onehot=num_classes)
])
dice_metric = DiceMetric(
include_background=True,
reduction="mean",
get_not_nans=False,
num_classes=num_classes
)
scaler = GradScaler(enabled=use_amp)
# Training-state tracking
best_dice = -1
best_epoch = 0
no_improve_counter = 0
patience = 12  # stop after 12 epochs without improvement
os.makedirs("optimized_checkpoints", exist_ok=True)
# =============== Memory-friendly training loop ===============
for epoch in range(1, max_epochs + 1):
    print(f"\nEpoch {epoch}/{max_epochs}")
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()
    # Training phase with gradient accumulation
    for step, batch in enumerate(tqdm(train_loader, desc="Train")):
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)
        text_feat = get_text_features(images.shape[0])
        with autocast(enabled=use_amp):
            outputs = model(images, text_feat)
            loss = loss_fn(outputs, labels)
            loss = loss / accumulation_steps  # scale the loss for gradient accumulation
        # Backward pass (scaled for AMP)
        scaler.scale(loss).backward()
        # Gradient accumulation: update the weights every accumulation_steps steps
        if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        # Periodic manual memory cleanup
        if step % 10 == 0:
            torch.cuda.empty_cache()
            gc.collect()
        epoch_loss += loss.item() * accumulation_steps
    # Average training loss for the epoch
    avg_train_loss = epoch_loss / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Train Loss: {avg_train_loss:.4f} | LR: {current_lr:.2e}")
    # Update the learning rate
    scheduler.step()
    # Validation phase
    model.eval()
    val_dices = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Val"):
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            text_feat = get_text_features(images.shape[0])
            with autocast(enabled=use_amp):
                outputs, _, _ = model(images, text_feat)  # use the main output only
            # Post-processing and metric computation
            outputs_list = decollate_batch(outputs)
            labels_list = decollate_batch(labels)
            outputs_convert = [post_pred(o) for o in outputs_list]
            labels_convert = [post_label(l) for l in labels_list]
            dice_metric(y_pred=outputs_convert, y=labels_convert)
            val_dices.append(dice_metric.aggregate().item())
            dice_metric.reset()
            # Manual memory cleanup
            torch.cuda.empty_cache()
            gc.collect()
    avg_dice = np.mean(val_dices)
    print(f"Val Dice: {avg_dice:.4f}")
    # Early stopping and best-model saving
    if avg_dice > best_dice:
        best_dice = avg_dice
        best_epoch = epoch
        no_improve_counter = 0
        torch.save(
            model.state_dict(),
            f"optimized_checkpoints/best_model_epoch{epoch}_dice{avg_dice:.4f}.pth"
        )
        print(f"✅ Saved best model @ epoch {epoch} | Dice: {avg_dice:.4f}")
    else:
        no_improve_counter += 1
        print(f"⏳ No improvement for {no_improve_counter}/{patience} epochs")
        if no_improve_counter >= patience:
            print(f"🛑 Early stopping triggered! Best Dice: {best_dice:.4f} @ epoch {best_epoch}")
            break
    # Save periodic checkpoints but limit how many are kept
    if epoch % 10 == 0:
        # Keep only the three most recent checkpoints
        checkpoint_files = glob("optimized_checkpoints/checkpoint_*.pth")
        checkpoint_files.sort(key=os.path.getmtime)
        for old_checkpoint in checkpoint_files[:-3]:
            os.remove(old_checkpoint)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
            'dice': avg_dice
        }, f"optimized_checkpoints/checkpoint_epoch{epoch}.pth")
    # Full memory cleanup every 5 epochs
    if epoch % 5 == 0:
        torch.cuda.empty_cache()
        gc.collect()
print("Training complete!")
# Note: the original version of this script crashed while building the datasets with
#   TypeError: __init__() got an unexpected keyword argument 'cache_rate'
# because monai.data.PersistentDataset does not accept a cache_rate argument (it caches every
# transformed sample to disk). The PersistentDataset calls above have been updated accordingly;
# see the commented CacheDataset alternative next to them if partial in-memory caching is
# what was intended.