pytorch:参数pin_memory=True和non_blocking=True的作用

本文介绍如何通过设置pin_memory和non_blocking参数来提高PyTorch数据加载和传输效率,实现训练过程加速。pin_memory用于把数据存放在锁页内存中以保持与GPU的高速数据交换,non_blocking则允许主机到GPU的数据拷贝以异步(非阻塞)方式执行,两者结合使用可以显著提升模型训练的速度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、pin_memory

pin_memory是DataLoader()的参数,默认值为False,其作用是设置是否把数据存放在锁页内存中。主机内存按是否会与虚拟内存(硬盘)发生换页,分为锁页内存和不锁页内存:

锁页内存:数据固定存放在物理内存(内存条)上;
不锁页内存:当物理内存(内存条)满载时,把部分数据转移到虚拟内存(硬盘)上。
锁页内存(pin_memory)能够保持与GPU进行高速传输,在训练时加快数据的读取,从而加快训练速度。因此,如果主机/服务器的内存足够大,建议把pin_memory设为True,如:

trainloader = torch.utils.data.DataLoader(dataset=traindata, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True)

二、non_blocking

non_blocking是cuda()/to()的参数,默认值为False。它与pin_memory配合使用:pin_memory把数据固定在物理内存(内存条)的锁页区域,而non_blocking=True表示主机到GPU(显存)的数据拷贝以异步方式执行,使传输与计算重叠。一般地,如果pin_memory为True,把non_blocking也设为True,有助于加速数据传输,加快训练过程,如:

# FIX: non_blocking is a parameter of Tensor.cuda()/Tensor.to(), NOT of
# nn.Module.cuda() — `Model().cuda(non_blocking=True)` raises
# TypeError: cuda() got an unexpected keyword argument 'non_blocking'.
# Apply it to the data tensors coming from a pin_memory=True DataLoader:
model = Model().cuda()
inputs = inputs.cuda(non_blocking=True)  # or: inputs.to(device, non_blocking=True)
# ======================
# UNETR training script (LiTS liver/tumor segmentation)
# ======================
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd, RandRotate90d,
    EnsureTyped, Activations, AsDiscrete, Resized, RandZoomd,
    RandGaussianNoised, CenterSpatialCropd
)
from monai.data import list_data_collate, Dataset  # plain (non-caching) Dataset
from monai.networks.nets import UNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from glob import glob
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
import gc
import nibabel as nib
import sys
import monai


class ExtractFirstSampledDict(monai.transforms.Transform):
    """Unwrap the output of RandCropByPosNegLabeld (num_samples=1).

    RandCropByPosNegLabeld returns a *list of dicts* (one dict per sample),
    not a dict of lists. The original implementation called ``data.items()``
    directly and crashed with "list indices must be integers or slices, not
    str" as soon as the list arrived (the exact error seen in the log).
    Handle both shapes.
    """

    def __call__(self, data):
        # FIX: the cropped sample arrives as [dict]; take the single element.
        if isinstance(data, list):
            return data[0] if len(data) == 1 else data
        out = {}
        for k, v in data.items():
            out[k] = v[0] if (isinstance(v, list) and len(v) == 1) else v
        return out


# ======================
# Configuration
# ======================
root_dir = "datasets/LiTS/processed"
images_dir = os.path.join(root_dir, "images")
labels_dir = os.path.join(root_dir, "labels")

max_epochs = 200
batch_size = 1
learning_rate = 1e-4
num_classes = 3
warmup_epochs = 10
use_amp = False  # AMP is unstable with UNETR here; keep disabled

# Disable MetaTensor dicts to avoid decollate errors
os.environ["MONAI_USE_META_DICT"] = "0"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print environment info
print("===== 环境信息 =====")
print(f"Python版本: {sys.version}")
print(f"PyTorch版本: {torch.__version__}")
print(f"MONAI版本: {monai.__version__}")
print(f"nibabel版本: {nib.__version__}")
if torch.cuda.is_available():
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"cuDNN版本: {torch.backends.cudnn.version()}")


def get_valid_size(size, divisor=16):
    """Round each spatial dim down to a multiple of ``divisor`` (min ``divisor``)."""
    return tuple(max(divisor, (s // divisor) * divisor) for s in size)


base_size = (128, 128, 64)
resized_size = get_valid_size(base_size)
crop_size = get_valid_size((64, 64, 64))  # smaller crop to save GPU memory
print(f"输入尺寸: resized_size={resized_size}, crop_size={crop_size}")

# ======================
# Data preprocessing
# ======================
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label", spatial_size=crop_size,
        pos=1.0, neg=1.0, num_samples=1, image_threshold=0
    ),
    ExtractFirstSampledDict(),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.9, max_zoom=1.1, mode=("trilinear", "nearest")),
    RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.05),
    EnsureTyped(keys=["image", "label"], data_type="tensor"),
])

val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),
    EnsureTyped(keys=["image", "label"], data_type="tensor"),
])

images = sorted(glob(os.path.join(images_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(labels_dir, "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)

train_ds = Dataset(data=train_files, transform=train_transforms)
val_ds = Dataset(data=val_files, transform=val_transforms)

train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True,
    num_workers=0,  # avoid multiprocessing issues
    collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
)
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()
)

# ======================
# Model
# ======================
model = UNETR(
    in_channels=1,
    out_channels=num_classes,
    img_size=crop_size,
    feature_size=16,
    hidden_size=512,
    mlp_dim=2048,
    num_heads=8,
    proj_type="perceptron",  # FIX: `pos_embed` deprecated since MONAI 1.2 (warning in log)
    norm_name="batch",
    res_block=True,
    dropout_rate=0.1
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"模型参数总数: {total_params / 1e6:.2f}M")

# ======================
# Loss + optimizer
# ======================
class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
# FIX: `ce_weight` deprecated since MONAI 1.2 (warning in log) — use `weight`
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, weight=class_weights,
                           lambda_dice=0.5, lambda_ce=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)


def lr_lambda(epoch):
    """Linear warmup for `warmup_epochs`, then cosine decay to 0."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * progress))


scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# ======================
# Evaluation
# ======================
# NOTE(review): post_pred/post_label are applied to whole batch tensors before
# being wrapped in a list for DiceMetric — works for batch_size=1 validation;
# verify before raising the validation batch size.
post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True)])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
dice_metric = DiceMetric(include_background=True, reduction="mean",
                         get_not_nans=False, num_classes=num_classes)
scaler = GradScaler(enabled=use_amp)

# ======================
# Training loop
# ======================
best_metric = -1
best_metric_epoch = -1
train_loss_history = []
val_dice_history = []
os.makedirs("unetr_checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)

print("\n===== 测试数据加载 =====")
try:
    test_sample = train_ds[0]
    print("数据加载测试成功!")
    print(f"图像形状: {test_sample['image'].shape}")
    print(f"标签形状: {test_sample['label'].shape}")
except Exception as e:
    print(f"数据加载失败: {str(e)}")
    print("\n尝试替代加载方式...")
    from monai.data import NibabelReader
    sample_file = train_files[0]
    reader = NibabelReader()
    img = reader.read(sample_file['image'])
    label = reader.read(sample_file['label'])
    print(f"手动加载成功 - 图像形状: {img.shape}, 标签形状: {label.shape}")

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch+1}/{max_epochs}")
    model.train()
    epoch_loss, step = 0, 0
    pbar_train = tqdm(total=len(train_loader), desc=f"训练 Epoch {epoch+1}")
    for batch_data in train_loader:
        step += 1
        try:
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            epoch_loss += loss.item()
            pbar_train.update(1)
            pbar_train.set_postfix({"loss": f"{loss.item():.4f}"})
            if step % 10 == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if 'CUDA out of memory' in str(e):
                print("\nCUDA内存不足,跳过该批次")
                torch.cuda.empty_cache()
                gc.collect()
            else:
                print(f"\n训练时发生错误: {str(e)}")
            continue
        except Exception as e:
            print(f"\n训练时发生未知错误: {str(e)}")
            continue
    pbar_train.close()

    # FIX: guard against ZeroDivisionError when the loader is empty
    epoch_loss /= max(step, 1)
    train_loss_history.append(epoch_loss)
    print(f"训练平均损失: {epoch_loss:.4f}")

    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"当前学习率: {current_lr:.7f}")

    model.eval()
    dice_vals = []
    pbar_val = tqdm(total=len(val_loader), desc=f"验证 Epoch {epoch+1}")
    with torch.no_grad():
        for val_data in val_loader:
            try:
                val_images = val_data["image"].to(device)
                val_labels = val_data["label"].to(device)
                val_outputs = model(val_images)
                val_preds = post_pred(val_outputs.cpu())
                val_truth = post_label(val_labels.cpu())
                dice_metric(y_pred=[val_preds], y=[val_truth])
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                dice_vals.append(metric)
                pbar_val.update(1)
                pbar_val.set_postfix({"dice": f"{metric:.4f}"})
            except RuntimeError as e:
                print(f"\n验证时发生错误: {str(e)}")
                continue
            except Exception as e:
                print(f"\n验证时发生未知错误: {str(e)}")
                continue
    pbar_val.close()

    avg_metric = np.mean(dice_vals) if dice_vals else 0.0
    val_dice_history.append(avg_metric)
    print(f"验证平均Dice: {avg_metric:.4f}")

    if avg_metric > best_metric:
        best_metric = avg_metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(),
                   f"unetr_checkpoints/best_model_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth")
        print(f"保存新的最佳模型! Epoch: {best_metric_epoch}, Dice: {best_metric:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            'dice': avg_metric
        }, f"unetr_checkpoints/checkpoint_epoch_{epoch+1}.pth")

    # Refresh the loss/Dice curves every epoch
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(train_loss_history, label='训练损失')
    plt.title('训练损失')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(val_dice_history, label='验证Dice', color='orange')
    plt.title('验证Dice')
    plt.xlabel('Epoch')
    plt.ylabel('Dice')
    plt.legend()
    plt.tight_layout()
    plt.savefig("logs/unetr_training_metrics.png")
    plt.close()

    torch.cuda.empty_cache()
    gc.collect()

print(f"\n训练完成! 最佳Dice: {best_metric:.4f} at epoch {best_metric_epoch}")
最佳Dice: {best_metric:.4f} at epoch {best_metric_epoch}") 这个代码covid_seg) (base) liulicheng@ailab-MS-7B79:~/MultiModal_MedSeg_2025$ /home/liulicheng/anaconda3/envs/covid_seg/bin/python /home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py ===== 环境信息 ===== Python版本: 3.8.12 | packaged by conda-forge | (default, Sep 29 2021, 19:52:28) [GCC 9.4.0] PyTorch版本: 2.1.0+cu118 MONAI版本: 1.3.2 nibabel版本: 5.2.1 CUDA版本: 11.8 cuDNN版本: 8700 输入尺寸: resized_size=(128, 128, 64), crop_size=(64, 64, 64) /home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.networks.nets.unetr UNETR.__init__:pos_embed: Argument `pos_embed` has been deprecated since version 1.2. It will be removed in version 1.4. please use `proj_type` instead. warn_deprecated(argname, msg, warning_category) 模型参数总数: 43.67M /home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.losses.dice DiceCELoss.__init__:ce_weight: Argument `ce_weight` has been deprecated since version 1.2. It will be removed in version 1.4. please use `weight` instead. warn_deprecated(argname, msg, warning_category) ===== 测试数据加载 ===== 数据加载测试成功! 数据加载失败: list indices must be integers or slices, not str 尝试替代加载方式... 手动加载成功 - 图像形状: (512, 512, 94), 标签形状: (512, 512, 94) Epoch 1/200 训练 Epoch 1: 1%|█ | 1/104 [00:11<19:08, 11.15s/it, loss=1.0541]这样也太慢了吧
06-27
import os import gc import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from glob import glob from tqdm import tqdm from sklearn.model_selection import train_test_split from torch.cuda.amp import GradScaler, autocast from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts from monai.transforms import ( Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd, RandRotate90d, EnsureTyped, Resized, RandZoomd, RandGaussianNoised, CenterSpatialCropd, Activations, AsDiscrete, RandCoarseDropoutd, RandBiasFieldd ) from monai.data import PersistentDataset, list_data_collate, decollate_batch from monai.networks.nets import SwinUNETR from monai.metrics import DiceMetric from monai.losses import DiceCELoss, FocalLoss # ================ 内存优化配置 ================ # 设置环境变量减少内存碎片 os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # 启用cudnn基准测试加速但增加内存,根据GPU内存大小选择 torch.backends.cudnn.benchmark = True # 如果内存不足可设为False # ========================== 参数配置 ========================== root_dir = "datasets/LiTS/processed" images = sorted(glob(os.path.join(root_dir, "images", "*.nii.gz"))) labels = sorted(glob(os.path.join(root_dir, "labels", "*.nii.gz"))) data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)] # 内存优化:使用更小的验证集比例 train_files, val_files = train_test_split(data, test_size=0.15, random_state=42) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 训练参数 max_epochs = 200 batch_size = 1 # 3D模型batch_size保持1 num_classes = 3 learning_rate = 1e-4 clip_dim = 512 use_amp = True # 启用混合精度减少内存使用 accumulation_steps = 4 # 梯度累积步数,模拟更大batch size # 图像尺寸 - 调整到更小的尺寸以节省内存 base_size = (96, 96, 48) # 原始(128,128,64) crop_size = (64, 64, 32) # 原始(64,64,32) print(f"使用尺寸: crop={crop_size}") # ===================== 内存友好的数据预处理 ===================== train_transforms = Compose([ LoadImaged(keys=["image", "label"]), 
EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RAS"), Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")), ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True), # 移除Resized步骤直接裁剪,减少内存使用 RandCropByPosNegLabeld( keys=["image", "label"], label_key="label", spatial_size=crop_size, pos=1.0, neg=1.0, num_samples=1 ), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1), RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2), RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3), RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.1, mode=("trilinear", "nearest")), # 缩小缩放范围 RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.05), # 减小噪声幅度 # 添加内存友好的高级增强 RandCoarseDropoutd( keys=["image"], holes=5, # 减少空洞数量 spatial_size=(10, 10, 5), # 减小空洞尺寸 max_holes=8, prob=0.2, fill_value=0 ), RandBiasFieldd( keys=["image"], coeff_range=(0.05, 0.15), # 减小偏置场强度 prob=0.1 ), EnsureTyped(keys=["image", "label"]), ]) val_transforms = Compose([ LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys=["image", "label"]), Orientationd(keys=["image", "label"], axcodes="RAS"), Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")), ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True), CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size), # 直接裁剪 EnsureTyped(keys=["image", "label"]), ]) # 使用PersistentDataset并限制缓存大小 os.makedirs("./cache/train", exist_ok=True) os.makedirs("./cache/val", exist_ok=True) train_ds = PersistentDataset( train_files, transform=train_transforms, cache_dir="./cache/train", cache_rate=0.6 # 只缓存60%的数据以减少内存 ) val_ds = PersistentDataset( val_files, transform=val_transforms, cache_dir="./cache/val", cache_rate=1.0 # 验证集完全缓存 ) # 数据加载器 - 减少num_workers节省内存 train_loader = 
DataLoader( train_ds, batch_size=batch_size, shuffle=True, collate_fn=list_data_collate, num_workers=2, # 减少worker数量 pin_memory=True ) val_loader = DataLoader( val_ds, batch_size=1, shuffle=False, collate_fn=list_data_collate, num_workers=1, # 减少worker数量 pin_memory=True ) # =============== 加载文本特征 =============== # 内存优化:使用内存映射加载大文件 clip_feats = np.load("./clip_text_features.npy", mmap_mode='r') clip_feats_tensor = torch.from_numpy(np.array(clip_feats)).float().to(device) def get_text_features(bs): """内存友好的文本特征获取""" idx = torch.randint(0, len(clip_feats), (bs,)) # 使用索引直接从内存映射中获取 return torch.tensor(clip_feats[idx]).float().to(device) # =============== 融合模块定义 =============== class MemoryEfficientCrossAttention(nn.Module): """内存优化的交叉注意力模块""" def __init__(self, img_dim=192, text_dim=512, num_heads=4): super().__init__() self.num_heads = num_heads self.head_dim = img_dim // num_heads # 使用更小的线性层 self.qkv = nn.Linear(img_dim, img_dim * 3, bias=False) self.text_proj = nn.Linear(text_dim, img_dim) self.out = nn.Linear(img_dim, img_dim) def forward(self, img_feat, text_feat): B, C, D, H, W = img_feat.shape N = D * H * W img_flat = img_feat.view(B, C, N).permute(0, 2, 1) # (B, N, C) # 多头注意力机制 qkv = self.qkv(img_flat).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # (B, num_heads, N, head_dim) # 文本特征处理 text_feat = self.text_proj(text_feat).view(B, 1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # (B, num_heads, 1, head_dim) # 注意力计算 - 使用缩放点积 attn = torch.matmul(q, text_feat.transpose(-2, -1)) / (self.head_dim ** 0.5) attn = torch.softmax(attn, dim=-1) # 上下文向量 context = torch.matmul(attn, v) # (B, num_heads, N, head_dim) context = context.transpose(1, 2).contiguous().view(B, N, C) # (B, N, C) # 输出投影 out = self.out(context).permute(0, 2, 1).view(B, C, D, H, W) return img_feat + out # =============== 主模型定义 =============== class EfficientSwinUNETR(SwinUNETR): """内存优化的SwinUNETR变体""" def __init__(self, img_size, 
in_channels, out_channels, feature_size=12, text_feat_dim=512): super().__init__( img_size=img_size, in_channels=in_channels, out_channels=out_channels, feature_size=feature_size, depths=(2, 2, 2, 2), # 减少层数节省内存 num_heads=(3, 6, 12, 24) # 减少头数 ) # 添加多尺度融合模块 self.fusion_low = MemoryEfficientCrossAttention(img_dim=feature_size*4, text_dim=text_feat_dim) self.fusion_mid = MemoryEfficientCrossAttention(img_dim=feature_size*8, text_dim=text_feat_dim) self.fusion_high = MemoryEfficientCrossAttention(img_dim=feature_size*16, text_dim=text_feat_dim) # 深度监督输出 self.aux_out1 = nn.Conv3d(feature_size*8, out_channels, kernel_size=1) self.aux_out2 = nn.Conv3d(feature_size*4, out_channels, kernel_size=1) def forward(self, x, text_feat=None): # 获取编码器输出 enc_out = self.swinViT(x) # [x0, x1, x2, x3, x4] # 多尺度融合 if text_feat is not None: if text_feat.dim() == 1: text_feat = text_feat.unsqueeze(0) enc_out[2] = self.fusion_low(enc_out[2], text_feat) # 低层特征融合 enc_out[3] = self.fusion_mid(enc_out[3], text_feat) # 中层特征融合 enc_out[4] = self.fusion_high(enc_out[4], text_feat) # 高层特征融合 # 原始解码器 dec_out = super().forward(x) # 深度监督输出 aux1 = self.aux_out1(enc_out[3]) aux2 = self.aux_out2(enc_out[2]) # 上采样辅助输出到原始尺寸 aux1 = F.interpolate(aux1, size=x.shape[2:], mode='trilinear', align_corners=False) aux2 = F.interpolate(aux2, size=x.shape[2:], mode='trilinear', align_corners=False) return dec_out, aux1, aux2 # =============== 模型训练相关 =============== # 初始化模型 model = EfficientSwinUNETR( img_size=crop_size, in_channels=1, out_channels=num_classes, feature_size=10, # 减少特征大小节省内存 text_feat_dim=clip_dim ).to(device) # 内存优化:梯度检查点 - 减少内存峰值 for module in model.modules(): if hasattr(module, 'set_grad_checkpointing'): module.set_grad_checkpointing(True) # 混合损失函数 class CombinedLoss(nn.Module): """组合Dice、交叉熵Focal损失""" def __init__(self, weights=[0.7, 0.2, 0.1]): super().__init__() self.dice_ce = DiceCELoss( to_onehot_y=True, softmax=True, include_background=True, weight=torch.tensor([0.2, 0.3, 0.5]).to(device) ) 
self.focal = FocalLoss(to_onehot_y=True, gamma=2.0) self.weights = weights def forward(self, outputs, target): main_out, aux1, aux2 = outputs # 主输出损失 loss_main = self.dice_ce(main_out, target) + self.focal(main_out, target) # 辅助输出损失 loss_aux1 = self.dice_ce(aux1, target) + self.focal(aux1, target) loss_aux2 = self.dice_ce(aux2, target) + self.focal(aux2, target) # 加权组合 total_loss = ( self.weights[0] * loss_main + self.weights[1] * loss_aux1 + self.weights[2] * loss_aux2 ) return total_loss loss_fn = CombinedLoss().to(device) # 优化器学习率调度 optimizer = torch.optim.AdamW( model.parameters(), lr=learning_rate, weight_decay=1e-5, betas=(0.9, 0.98) # 调整beta减少内存波动 ) scheduler = CosineAnnealingWarmRestarts( optimizer, T_0=20, # 每20个epoch重置一次 T_mult=1, # 保持周期不变 eta_min=1e-6 ) # 评估相关 post_pred = Compose([ Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=num_classes) ]) post_label = Compose([ AsDiscrete(to_onehot=num_classes) ]) dice_metric = DiceMetric( include_background=True, reduction="mean", get_not_nans=False, num_classes=num_classes ) scaler = GradScaler(enabled=use_amp) # 训练状态跟踪 best_dice = -1 best_epoch = 0 no_improve_counter = 0 patience = 12 # 12个epoch无改进则停止 os.makedirs("optimized_checkpoints", exist_ok=True) # =============== 内存友好的训练循环 =============== for epoch in range(1, max_epochs + 1): print(f"\nEpoch {epoch}/{max_epochs}") model.train() epoch_loss = 0 optimizer.zero_grad() # 训练阶段 - 使用梯度累积 for step, batch in enumerate(tqdm(train_loader, desc="Train")): images = batch["image"].to(device, non_blocking=True) labels = batch["label"].to(device, non_blocking=True) text_feat = get_text_features(images.shape[0]) with autocast(enabled=use_amp): outputs = model(images, text_feat) loss = loss_fn(outputs, labels) loss = loss / accumulation_steps # 梯度累积缩放损失 # 反向传播 scaler.scale(loss).backward() # 梯度累积:每accumulation_steps步更新一次 if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_loader): scaler.step(optimizer) scaler.update() optimizer.zero_grad() # 
手动内存清理 if step % 10 == 0: torch.cuda.empty_cache() gc.collect() epoch_loss += loss.item() * accumulation_steps # 计算平均训练损失 avg_train_loss = epoch_loss / len(train_loader) current_lr = optimizer.param_groups[0]['lr'] print(f"Train Loss: {avg_train_loss:.4f} | LR: {current_lr:.2e}") # 更新学习率 scheduler.step() # 验证阶段 model.eval() val_dices = [] with torch.no_grad(): for batch in tqdm(val_loader, desc="Val"): images = batch["image"].to(device, non_blocking=True) labels = batch["label"].to(device, non_blocking=True) text_feat = get_text_features(images.shape[0]) with autocast(enabled=use_amp): outputs, _, _ = model(images, text_feat) # 只使用主输出 # 后处理指标计算 outputs_list = decollate_batch(outputs) labels_list = decollate_batch(labels) outputs_convert = [post_pred(o) for o in outputs_list] labels_convert = [post_label(l) for l in labels_list] dice_metric(y_pred=outputs_convert, y=labels_convert) val_dices.append(dice_metric.aggregate().item()) dice_metric.reset() # 手动内存清理 torch.cuda.empty_cache() gc.collect() avg_dice = np.mean(val_dices) print(f"Val Dice: {avg_dice:.4f}") # 早停机制模型保存 if avg_dice > best_dice: best_dice = avg_dice best_epoch = epoch no_improve_counter = 0 torch.save( model.state_dict(), f"optimized_checkpoints/best_model_epoch{epoch}_dice{avg_dice:.4f}.pth" ) print(f"✅ 保存最佳模型 @ epoch {epoch} | Dice: {avg_dice:.4f}") else: no_improve_counter += 1 print(f"⏳ 未改进次数: {no_improve_counter}/{patience}") if no_improve_counter >= patience: print(f"🛑 早停触发! 
最佳Dice: {best_dice:.4f} @ epoch {best_epoch}") break # 定期保存检查点但限制数量 if epoch % 10 == 0: # 只保留最新的3个检查点 checkpoint_files = glob("optimized_checkpoints/checkpoint_*.pth") checkpoint_files.sort(key=os.path.getmtime) for old_checkpoint in checkpoint_files[:-3]: os.remove(old_checkpoint) torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': avg_train_loss, 'dice': avg_dice }, f"optimized_checkpoints/checkpoint_epoch{epoch}.pth") # 每5个epoch进行一次完整内存清理 if epoch % 5 == 0: torch.cuda.empty_cache() gc.collect() print("训练完成!")这份代码报错啦(covid_seg) (base) liulicheng@ailab-MS-7B79:~/MultiModal_MedSeg_2025$ /home/liulicheng/anaconda3/envs/covid_seg/bin/python /home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clip_multiscale_fusion.py 使用尺寸: crop=(64, 64, 32) Traceback (most recent call last): File "/home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clip_multiscale_fusion.py", line 108, in <module> train_ds = PersistentDataset( TypeError: __init__() got an unexpected keyword argument 'cache_rate'
最新发布
07-01
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值