CrossEntropyLoss with reduction='none' causes a program error

When using `nn.CrossEntropyLoss`, setting `reduction='none'` made it impossible to plot the loss curve, and in this setup `reduction='sum'` did not help either; changing `reduction` to `'mean'` let the code run normally. The underlying reason is that `reduction='none'` returns one loss value per sample instead of a single scalar, so any downstream code that expects a scalar loss (for example, calling `loss.item()` to log a loss curve) breaks, whereas `'mean'` collapses the per-sample losses into a scalar.


```python
# loss function: cross-entropy
loss = nn.CrossEntropyLoss(reduction='none')
```

With the code above, the loss curve cannot be plotted (`reduction='sum'` did not work either); after changing the value to `'mean'`, the script runs through normally.

```python
# loss function: cross-entropy
loss = nn.CrossEntropyLoss(reduction='mean')
```
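The difference comes down to the shape of the returned loss: with `reduction='none'`, `CrossEntropyLoss` returns a tensor with one entry per sample rather than a 0-dim scalar, so `loss.item()` (typical in loss-curve logging) raises a `RuntimeError`. A minimal sketch illustrating the shapes:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 3)               # batch of 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 0])

loss_none = nn.CrossEntropyLoss(reduction='none')(logits, target)
print(loss_none.shape)                   # torch.Size([4]): one loss per sample
# loss_none.item() raises a RuntimeError here (more than one element)

loss_mean = nn.CrossEntropyLoss(reduction='mean')(logits, target)
print(loss_mean.shape)                   # torch.Size([]): a scalar, .item() works

# If per-sample losses are needed (e.g. for sample weighting), reduce
# manually before logging/plotting:
history = [loss_none.mean().item()]
```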
```python
import os
import gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd, RandRotate90d,
    EnsureTyped, Resized, RandZoomd, RandGaussianNoised, CenterSpatialCropd,
    Activations, AsDiscrete, RandCoarseDropoutd, RandBiasFieldd
)
from monai.data import PersistentDataset, list_data_collate, decollate_batch
from monai.networks.nets import SwinUNETR
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss, FocalLoss

# ================ Memory-optimization settings ================
# Environment variable to reduce memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# cudnn benchmarking speeds things up but uses more memory; choose per GPU
torch.backends.cudnn.benchmark = True  # set to False if memory is tight

# ========================== Configuration ==========================
root_dir = "datasets/LiTS/processed"
images = sorted(glob(os.path.join(root_dir, "images", "*.nii.gz")))
labels = sorted(glob(os.path.join(root_dir, "labels", "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]

# Memory optimization: smaller validation split
train_files, val_files = train_test_split(data, test_size=0.15, random_state=42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training parameters
max_epochs = 200
batch_size = 1            # keep batch_size = 1 for the 3D model
num_classes = 3
learning_rate = 1e-4
clip_dim = 512
use_amp = True            # mixed precision to reduce memory use
accumulation_steps = 4    # gradient-accumulation steps to simulate a larger batch

# Image sizes - reduced to save memory
base_size = (96, 96, 48)  # originally (128, 128, 64)
crop_size = (64, 64, 32)  # originally (64, 64, 32)
print(f"使用尺寸: crop={crop_size}")

# ===================== Memory-friendly preprocessing =====================
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    # Crop directly instead of a Resized step to reduce memory use
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=crop_size,
        pos=1.0,
        neg=1.0,
        num_samples=1
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.1,
              mode=("trilinear", "nearest")),  # narrower zoom range
    RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.05),  # weaker noise
    # Memory-friendly advanced augmentations
    RandCoarseDropoutd(
        keys=["image"],
        holes=5,                   # fewer holes
        spatial_size=(10, 10, 5),  # smaller holes
        max_holes=8,
        prob=0.2,
        fill_value=0
    ),
    RandBiasFieldd(
        keys=["image"],
        coeff_range=(0.05, 0.15),  # weaker bias field
        prob=0.1
    ),
    EnsureTyped(keys=["image", "label"]),
])

val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),  # crop directly
    EnsureTyped(keys=["image", "label"]),
])

# Use PersistentDataset and limit the cache size
os.makedirs("./cache/train", exist_ok=True)
os.makedirs("./cache/val", exist_ok=True)

train_ds = PersistentDataset(
    train_files,
    transform=train_transforms,
    cache_dir="./cache/train",
    cache_rate=0.6  # cache only 60% of the data to reduce memory
)
val_ds = PersistentDataset(
    val_files,
    transform=val_transforms,
    cache_dir="./cache/val",
    cache_rate=1.0  # fully cache the validation set
)

# Data loaders - fewer workers to save memory
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=list_data_collate,
    num_workers=2,  # fewer workers
    pin_memory=True
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    collate_fn=list_data_collate,
    num_workers=1,  # fewer workers
    pin_memory=True
)

# =============== Load text features ===============
# Memory optimization: memory-map the large file
clip_feats = np.load("./clip_text_features.npy", mmap_mode='r')
clip_feats_tensor = torch.from_numpy(np.array(clip_feats)).float().to(device)

def get_text_features(bs):
    """Memory-friendly text-feature sampling."""
    idx = torch.randint(0, len(clip_feats), (bs,))
    # Index directly into the memory map
    return torch.tensor(clip_feats[idx]).float().to(device)

# =============== Fusion module ===============
class MemoryEfficientCrossAttention(nn.Module):
    """Memory-optimized cross-attention module."""
    def __init__(self, img_dim=192, text_dim=512, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = img_dim // num_heads
        # Smaller linear layers
        self.qkv = nn.Linear(img_dim, img_dim * 3, bias=False)
        self.text_proj = nn.Linear(text_dim, img_dim)
        self.out = nn.Linear(img_dim, img_dim)

    def forward(self, img_feat, text_feat):
        B, C, D, H, W = img_feat.shape
        N = D * H * W
        img_flat = img_feat.view(B, C, N).permute(0, 2, 1)  # (B, N, C)

        # Multi-head attention
        qkv = self.qkv(img_flat).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, num_heads, N, head_dim)

        # Project the text features
        text_feat = self.text_proj(text_feat).view(B, 1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, num_heads, 1, head_dim)

        # Scaled dot-product attention
        attn = torch.matmul(q, text_feat.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(attn, dim=-1)

        # Context vectors
        context = torch.matmul(attn, v)  # (B, num_heads, N, head_dim)
        context = context.transpose(1, 2).contiguous().view(B, N, C)  # (B, N, C)

        # Output projection
        out = self.out(context).permute(0, 2, 1).view(B, C, D, H, W)
        return img_feat + out

# =============== Main model ===============
class EfficientSwinUNETR(SwinUNETR):
    """Memory-optimized SwinUNETR variant."""
    def __init__(self, img_size, in_channels, out_channels, feature_size=12, text_feat_dim=512):
        super().__init__(
            img_size=img_size,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_size=feature_size,
            depths=(2, 2, 2, 2),      # fewer layers to save memory
            num_heads=(3, 6, 12, 24)  # fewer heads
        )
        # Multi-scale fusion modules
        self.fusion_low = MemoryEfficientCrossAttention(img_dim=feature_size*4, text_dim=text_feat_dim)
        self.fusion_mid = MemoryEfficientCrossAttention(img_dim=feature_size*8, text_dim=text_feat_dim)
        self.fusion_high = MemoryEfficientCrossAttention(img_dim=feature_size*16, text_dim=text_feat_dim)
        # Deep-supervision heads
        self.aux_out1 = nn.Conv3d(feature_size*8, out_channels, kernel_size=1)
        self.aux_out2 = nn.Conv3d(feature_size*4, out_channels, kernel_size=1)

    def forward(self, x, text_feat=None):
        # Encoder outputs
        enc_out = self.swinViT(x)  # [x0, x1, x2, x3, x4]

        # Multi-scale fusion
        if text_feat is not None:
            if text_feat.dim() == 1:
                text_feat = text_feat.unsqueeze(0)
            enc_out[2] = self.fusion_low(enc_out[2], text_feat)   # low-level fusion
            enc_out[3] = self.fusion_mid(enc_out[3], text_feat)   # mid-level fusion
            enc_out[4] = self.fusion_high(enc_out[4], text_feat)  # high-level fusion

        # Original decoder
        dec_out = super().forward(x)

        # Deep-supervision outputs
        aux1 = self.aux_out1(enc_out[3])
        aux2 = self.aux_out2(enc_out[2])
        # Upsample auxiliary outputs to the input size
        aux1 = F.interpolate(aux1, size=x.shape[2:], mode='trilinear', align_corners=False)
        aux2 = F.interpolate(aux2, size=x.shape[2:], mode='trilinear', align_corners=False)
        return dec_out, aux1, aux2

# =============== Training setup ===============
# Initialize the model
model = EfficientSwinUNETR(
    img_size=crop_size,
    in_channels=1,
    out_channels=num_classes,
    feature_size=10,  # smaller feature size to save memory
    text_feat_dim=clip_dim
).to(device)

# Memory optimization: gradient checkpointing to lower the peak
for module in model.modules():
    if hasattr(module, 'set_grad_checkpointing'):
        module.set_grad_checkpointing(True)

# Combined loss
class CombinedLoss(nn.Module):
    """Combined Dice, cross-entropy, and focal loss."""
    def __init__(self, weights=[0.7, 0.2, 0.1]):
        super().__init__()
        self.dice_ce = DiceCELoss(
            to_onehot_y=True,
            softmax=True,
            include_background=True,
            weight=torch.tensor([0.2, 0.3, 0.5]).to(device)
        )
        self.focal = FocalLoss(to_onehot_y=True, gamma=2.0)
        self.weights = weights

    def forward(self, outputs, target):
        main_out, aux1, aux2 = outputs
        # Main-output loss
        loss_main = self.dice_ce(main_out, target) + self.focal(main_out, target)
        # Auxiliary-output losses
        loss_aux1 = self.dice_ce(aux1, target) + self.focal(aux1, target)
        loss_aux2 = self.dice_ce(aux2, target) + self.focal(aux2, target)
        # Weighted combination
        total_loss = (
            self.weights[0] * loss_main +
            self.weights[1] * loss_aux1 +
            self.weights[2] * loss_aux2
        )
        return total_loss

loss_fn = CombinedLoss().to(device)

# Optimizer and LR schedule
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
    betas=(0.9, 0.98)  # adjusted betas to reduce memory fluctuation
)
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=20,    # restart every 20 epochs
    T_mult=1,  # keep the cycle length constant
    eta_min=1e-6
)

# Evaluation
post_pred = Compose([
    Activations(softmax=True),
    AsDiscrete(argmax=True, to_onehot=num_classes)
])
post_label = Compose([
    AsDiscrete(to_onehot=num_classes)
])
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
    num_classes=num_classes
)
scaler = GradScaler(enabled=use_amp)

# Training-state tracking
best_dice = -1
best_epoch = 0
no_improve_counter = 0
patience = 12  # stop after 12 epochs without improvement
os.makedirs("optimized_checkpoints", exist_ok=True)

# =============== Memory-friendly training loop ===============
for epoch in range(1, max_epochs + 1):
    print(f"\nEpoch {epoch}/{max_epochs}")
    model.train()
    epoch_loss = 0
    optimizer.zero_grad()

    # Training phase with gradient accumulation
    for step, batch in enumerate(tqdm(train_loader, desc="Train")):
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)
        text_feat = get_text_features(images.shape[0])

        with autocast(enabled=use_amp):
            outputs = model(images, text_feat)
            loss = loss_fn(outputs, labels)
            loss = loss / accumulation_steps  # scale the loss for accumulation

        # Backward pass
        scaler.scale(loss).backward()

        # Gradient accumulation: update every accumulation_steps steps
        if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Manual memory cleanup
        if step % 10 == 0:
            torch.cuda.empty_cache()
            gc.collect()

        epoch_loss += loss.item() * accumulation_steps

    # Average training loss
    avg_train_loss = epoch_loss / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Train Loss: {avg_train_loss:.4f} | LR: {current_lr:.2e}")

    # Update the learning rate
    scheduler.step()

    # Validation phase
    model.eval()
    val_dices = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Val"):
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            text_feat = get_text_features(images.shape[0])

            with autocast(enabled=use_amp):
                outputs, _, _ = model(images, text_feat)  # main output only

            # Post-processing and metrics
            outputs_list = decollate_batch(outputs)
            labels_list = decollate_batch(labels)
            outputs_convert = [post_pred(o) for o in outputs_list]
            labels_convert = [post_label(l) for l in labels_list]
            dice_metric(y_pred=outputs_convert, y=labels_convert)
            val_dices.append(dice_metric.aggregate().item())
            dice_metric.reset()

            # Manual memory cleanup
            torch.cuda.empty_cache()
            gc.collect()

    avg_dice = np.mean(val_dices)
    print(f"Val Dice: {avg_dice:.4f}")

    # Early stopping and best-model saving
    if avg_dice > best_dice:
        best_dice = avg_dice
        best_epoch = epoch
        no_improve_counter = 0
        torch.save(
            model.state_dict(),
            f"optimized_checkpoints/best_model_epoch{epoch}_dice{avg_dice:.4f}.pth"
        )
        print(f"✅ 保存最佳模型 @ epoch {epoch} | Dice: {avg_dice:.4f}")
    else:
        no_improve_counter += 1
        print(f"⏳ 未改进次数: {no_improve_counter}/{patience}")
        if no_improve_counter >= patience:
            print(f"🛑 早停触发! 最佳Dice: {best_dice:.4f} @ epoch {best_epoch}")
            break

    # Periodic checkpoints, keeping only the latest three
    if epoch % 10 == 0:
        checkpoint_files = glob("optimized_checkpoints/checkpoint_*.pth")
        checkpoint_files.sort(key=os.path.getmtime)
        for old_checkpoint in checkpoint_files[:-3]:
            os.remove(old_checkpoint)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
            'dice': avg_dice
        }, f"optimized_checkpoints/checkpoint_epoch{epoch}.pth")

    # Full memory cleanup every 5 epochs
    if epoch % 5 == 0:
        torch.cuda.empty_cache()
        gc.collect()

print("训练完成!")
```

This code fails at startup:

```text
(covid_seg) (base) liulicheng@ailab-MS-7B79:~/MultiModal_MedSeg_2025$ /home/liulicheng/anaconda3/envs/covid_seg/bin/python /home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clip_multiscale_fusion.py
使用尺寸: crop=(64, 64, 32)
Traceback (most recent call last):
  File "/home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clip_multiscale_fusion.py", line 108, in <module>
    train_ds = PersistentDataset(
TypeError: __init__() got an unexpected keyword argument 'cache_rate'
```
We hit `TypeError: __init__() got an unexpected keyword argument 'cache_rate'`. The cause is that MONAI's `PersistentDataset` takes no `cache_rate` parameter: per the MONAI documentation, its constructor accepts `data`, `transform`, `cache_dir`, `hash_func`, and similar arguments, and it always persists every transformed sample to disk, so a cache ratio simply does not apply to it.

There are two options. First, remove `cache_rate` and let all data be cached on disk; disk space is usually plentiful, and disk caching keeps RAM usage low, which matches the memory-optimization goal here. Second, switch to `CacheDataset`, which does support `cache_rate`, but it caches in RAM and works against that goal. (If we truly wanted to cache only part of the data, we could also subsample the file list before building the dataset, at the cost of training data.) Removing the parameter from both datasets is the better fit.

A second latent problem hides in the text-feature loading. The file is opened with `np.load(..., mmap_mode='r')` but then immediately materialized with `np.array(clip_feats)`, which reads the whole array into memory and defeats the memory map; `get_text_features` then copies from the map on every batch. CLIP text features are tiny (for example, 1000 samples × 512 dims × 4 bytes ≈ 2 MB), so the simplest and fastest approach is to load them once into a device tensor and index it directly:

```python
train_ds = PersistentDataset(train_files, transform=train_transforms, cache_dir="./cache/train")
val_ds = PersistentDataset(val_files, transform=val_transforms, cache_dir="./cache/val")

clip_feats = np.load("./clip_text_features.npy")
clip_feats = torch.from_numpy(clip_feats).float().to(device)

def get_text_features(bs):
    idx = torch.randint(0, len(clip_feats), (bs,))
    return clip_feats[idx]
```

With these two changes the script should get past the error; if anything else fails, it will need separate debugging, but the `cache_rate` argument is the problem reported here.
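Before committing to a fix like this, it is worth confirming which keyword arguments the installed MONAI version actually accepts; a quick diagnostic sketch using only the standard library:

```python
import inspect
from monai.data import PersistentDataset

# Print the constructor signature of the installed MONAI version to see
# exactly which keyword arguments PersistentDataset accepts.
print(inspect.signature(PersistentDataset.__init__))
```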
This error comes from passing `cache_rate` to MONAI's `PersistentDataset`, which has no such parameter. Here is the fix.

### Error analysis

`TypeError: __init__() got an unexpected keyword argument 'cache_rate'` means the `PersistentDataset` constructor does not accept `cache_rate`. Partial caching via `cache_rate` is a feature of `CacheDataset` (in-memory caching); `PersistentDataset` always caches every transformed sample to `cache_dir` on disk.

### Fix

Remove the `cache_rate` argument when creating the datasets:

```python
# ===================== Fixed dataset creation =====================
# PersistentDataset caches all transformed samples to cache_dir; it has no cache_rate.
train_ds = PersistentDataset(
    train_files,
    transform=train_transforms,
    cache_dir="./cache/train"
)

val_ds = PersistentDataset(
    val_files,
    transform=val_transforms,
    cache_dir="./cache/val"
)
```
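If partial caching is genuinely required, `CacheDataset` does accept `cache_rate`, at the cost of caching in RAM instead of on disk. A minimal sketch, assuming `train_files` and `train_transforms` from the script above:

```python
from monai.data import CacheDataset

# CacheDataset keeps transformed samples in RAM; cache_rate sets the fraction
# of the dataset cached up front (the rest is transformed on the fly).
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=0.6,   # cache 60% of the samples in memory
    num_workers=2,
)
```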
While we are at it, simplify the text-feature loading. The feature file is small enough to hold entirely on the device, so the memory map and the per-batch `torch.tensor(...)` copies can go:

```python
# =============== Load text features ===============
clip_feats = np.load("./clip_text_features.npy")
clip_feats_tensor = torch.from_numpy(clip_feats).float().to(device)

def get_text_features(bs):
    """Sample a batch of CLIP text-feature vectors already resident on the device."""
    idx = torch.randint(0, len(clip_feats), (bs,))
    return clip_feats_tensor[idx]
```

These are the only two sections that change; everything else in the training script (transforms, data loaders, fusion modules, model, loss, optimizer, and the training loop) stays exactly as in the original code above.

### Key changes explained

1. **Dataset creation**: `cache_rate` is removed from `PersistentDataset`. With the default behaviour, every sample is cached to the given directory after its first transform, so later epochs skip the preprocessing cost.
2. **Text-feature loading**: the `.npy` file is loaded once into a device tensor and indexed per batch, instead of re-reading a memory-mapped file every step.
3. **Memory management**: gradient accumulation, the periodic `torch.cuda.empty_cache()` / `gc.collect()` calls, and the reduced model width are all kept.

### Potential issues to pre-empt

1. **Cache directory permissions**: `./cache/train` and `./cache/val` must be writable; if permissions are a problem, switch to an absolute path such as `/tmp/cache/train`.
2. **Text-feature file path**: `./clip_text_features.npy` must exist, and the script must run from the matching working directory.
3. **MONAI version compatibility**: MONAI v0.9.0 or newer is recommended (`pip install monai --upgrade`).

These changes resolve the `cache_rate` error while keeping the original memory-optimization strategy intact.
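One more latent issue worth noting, which the script has not reached yet because it fails at dataset construction: in `MemoryEfficientCrossAttention.forward`, `attn` has shape `(B, heads, N, 1)` while `v` has shape `(B, heads, N, head_dim)`, so `torch.matmul(attn, v)` will raise a shape-mismatch error once training actually starts; with a single text token there is also only one key, so the softmax over keys is degenerate. A hypothetical repair sketch (not the original design) that swaps the degenerate attention for a sigmoid gate driven by the scaled image-text similarity:

```python
import torch
import torch.nn as nn

class TextGatedFusion(nn.Module):
    """Hypothetical fix sketch: use the scaled image-text similarity as a
    per-voxel, per-head sigmoid gate on the image values, instead of a
    softmax over a single key (which is degenerate and shape-incompatible)."""
    def __init__(self, img_dim=48, text_dim=512, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = img_dim // num_heads
        self.qkv = nn.Linear(img_dim, img_dim * 3, bias=False)
        self.text_proj = nn.Linear(text_dim, img_dim)
        self.out = nn.Linear(img_dim, img_dim)

    def forward(self, img_feat, text_feat):
        B, C, D, H, W = img_feat.shape
        N = D * H * W
        img_flat = img_feat.view(B, C, N).permute(0, 2, 1)  # (B, N, C)
        qkv = self.qkv(img_flat).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, _, v = qkv[0], qkv[1], qkv[2]                    # (B, heads, N, head_dim)
        t = self.text_proj(text_feat).view(B, 1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Similarity of each voxel query to the single text token: (B, heads, N, 1)
        sim = torch.matmul(q, t.transpose(-2, -1)) / (self.head_dim ** 0.5)
        context = torch.sigmoid(sim) * v                    # broadcast gate over head_dim
        context = context.transpose(1, 2).contiguous().view(B, N, C)
        out = self.out(context).permute(0, 2, 1).view(B, C, D, H, W)
        return img_feat + out

# Smoke test with toy shapes
m = TextGatedFusion(img_dim=48, text_dim=512)
x = torch.randn(1, 48, 8, 8, 4)
txt = torch.randn(1, 512)
print(m(x, txt).shape)  # torch.Size([1, 48, 8, 8, 4])
```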