Python: the statement mode = 'test' if y is None else 'train'


The statement mode = 'test' if y is None else 'train'
is equivalent to:

if y is None:
    mode = 'test'
else:
    mode = 'train'

A concrete example:

>>> def func(y):
...     if y is None:
...         mode = 'test'
...     else:
...         mode = 'train'
...     return mode
...
>>> y = 1
>>> mode = func(y)
>>> mode
'train'
>>> y = None
>>> mode = func(y)
>>> mode
'test'
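
For comparison, the same function collapses to a single return statement once the conditional expression is used (a minimal sketch; func_ternary is just an illustrative name):

>>> def func_ternary(y):
...     # The value before `if` is used when the condition holds, otherwise the value after `else`.
...     return 'test' if y is None else 'train'
...
>>> func_ternary(1)
'train'
>>> func_ternary(None)
'test'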

# ======================
# UNETR training script
# ======================
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd, EnsureTyped,
    Activations, AsDiscrete, ResizeWithPadOrCropd, RandZoomd, RandGaussianNoised
)
from monai.data import list_data_collate, CacheDataset  # use CacheDataset
from monai.networks.nets import UNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from glob import glob
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
import gc
import nibabel as nib
import sys
import monai
import time

# Custom transform: unwrap the single-element list returned by RandCropByPosNegLabeld
class ExtractFirstSampledDict(monai.transforms.Transform):
    def __call__(self, data):
        out = {}
        for k, v in data.items():
            if isinstance(v, list) and len(v) == 1:
                out[k] = v[0]
            else:
                out[k] = v
        return out

# ======================
# Configuration (optimized)
# ======================
root_dir = "datasets/LiTS/processed"
images_dir = os.path.join(root_dir, "images")
labels_dir = os.path.join(root_dir, "labels")

max_epochs = 200
batch_size = 2          # larger batch size
learning_rate = 1e-4
num_classes = 3
warmup_epochs = 10
use_amp = False         # AMP acceleration ✅
accumulation_steps = 2  # gradient accumulation steps

# Disable MetaTensor to avoid decollate errors
os.environ["MONAI_USE_META_DICT"] = "0"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Enable cuDNN benchmarking
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

# Print environment info
print("===== 环境信息 =====")
print(f"Python版本: {sys.version}")
print(f"PyTorch版本: {torch.__version__}")
print(f"MONAI版本: {monai.__version__}")
print(f"nibabel版本: {nib.__version__}")
if torch.cuda.is_available():
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"cuDNN版本: {torch.backends.cudnn.version()}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Size settings - make sure sizes are divisible by 16
def get_valid_size(size, divisor=16):
    return tuple([max(divisor, (s // divisor) * divisor) for s in size])

base_size = (128, 128, 64)
resized_size = get_valid_size(base_size)
crop_size = get_valid_size((48, 48, 48))  # smaller size to save GPU memory
print(f"输入尺寸: resized_size={resized_size}, crop_size={crop_size}")

# ======================
# Data preprocessing (optimized)
# ======================
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    ResizeWithPadOrCropd(  # lightweight resizing ✅
        keys=["image", "label"],
        spatial_size=crop_size,
        mode="constant"
    ),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=crop_size,
        pos=1.0,
        neg=1.0,
        num_samples=1,
        image_threshold=0
    ),
    ExtractFirstSampledDict(),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    # time-consuming rotation augmentation removed ✅
    # RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.3, min_zoom=0.9, max_zoom=1.1, mode=("trilinear", "nearest")),  # lower probability
    RandGaussianNoised(keys=["image"], prob=0.1, mean=0.0, std=0.05),  # lower probability
    EnsureTyped(keys=["image", "label"], data_type="tensor"),
])

val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    ResizeWithPadOrCropd(  # lightweight resizing ✅
        keys=["image", "label"],
        spatial_size=resized_size,
        mode="constant"
    ),
    EnsureTyped(keys=["image", "label"], data_type="tensor"),
])

images = sorted(glob(os.path.join(images_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(labels_dir, "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)

# Use CacheDataset to speed up data loading ✅
cache_dir = "./data_cache"
os.makedirs(cache_dir, exist_ok=True)

train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,  # cache 100% of the data
    num_workers=0
)
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=0
)

# Optimized data loaders ✅
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,      # more worker processes
    prefetch_factor=2,  # prefetch data
    collate_fn=list_data_collate,
    pin_memory=True     # use pinned memory
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=list_data_collate,
    pin_memory=True
)

# ======================
# Model (optimized)
# ======================
model = UNETR(
    in_channels=1,
    out_channels=num_classes,
    img_size=resized_size,  # uses the validation size, e.g. (128, 128, 64)
    feature_size=16,
    hidden_size=192,
    mlp_dim=768,
    num_heads=3,
    proj_type="perceptron",  # ✅ replaces pos_embed
    norm_name="batch",
    res_block=True,
    dropout_rate=0.1
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"模型参数总数: {total_params / 1e6:.2f}M")

# ======================
# Loss + optimizer
# ======================
class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, ce_weight=class_weights, lambda_dice=0.5, lambda_ce=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# ======================
# Metrics
# ======================
post_pred = Compose([
    Activations(softmax=True),
    AsDiscrete(to_onehot=num_classes)
])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False, num_classes=num_classes)

scaler = GradScaler(enabled=use_amp)

# ======================
# Training loop (optimized)
# ======================
best_metric = -1
best_metric_epoch = -1
train_loss_history = []
val_dice_history = []

os.makedirs("unetr_checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)

# GPU warm-up
print("===== GPU预热 =====")
dummy_input = torch.randn(1, 1, *crop_size).to(device)
for _ in range(10):
    model(dummy_input)
torch.cuda.synchronize()

print("\n===== 开始训练 =====")
for epoch in range(max_epochs):
    print(f"\nEpoch {epoch+1}/{max_epochs}")
    start_time = time.time()
    model.train()
    epoch_loss, step = 0, 0
    pbar_train = tqdm(total=len(train_loader), desc=f"训练 Epoch {epoch+1}")
    optimizer.zero_grad()
    for batch_idx, batch_data in enumerate(train_loader):
        step += 1
        try:
            inputs = batch_data["image"].to(device, non_blocking=True)
            labels = batch_data["label"].to(device, non_blocking=True)
            # Automatic mixed precision
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = loss_function(outputs, labels) / accumulation_steps  # gradient accumulation ✅
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            # Update weights once every accumulation_steps steps
            if (batch_idx + 1) % accumulation_steps == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
            epoch_loss += loss.item() * accumulation_steps
            pbar_train.update(1)
            pbar_train.set_postfix({"loss": f"{loss.item() * accumulation_steps:.4f}"})
            # Free GPU memory every 10 batches
            if step % 10 == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if 'CUDA out of memory' in str(e):
                print("\nCUDA内存不足,跳过该批次")
                torch.cuda.empty_cache()
                gc.collect()
                optimizer.zero_grad()
            else:
                print(f"\n训练时发生错误: {str(e)}")
            continue
        except Exception as e:
            print(f"\n训练时发生未知错误: {str(e)}")
            continue
    # Flush any remaining accumulated gradients
    if (batch_idx + 1) % accumulation_steps != 0:
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
    pbar_train.close()
    epoch_loss /= len(train_loader)
    train_loss_history.append(epoch_loss)
    print(f"训练平均损失: {epoch_loss:.4f}")
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"当前学习率: {current_lr:.7f}")

    # Validation
    model.eval()
    dice_vals = []
    pbar_val = tqdm(total=len(val_loader), desc=f"验证 Epoch {epoch+1}")
    with torch.no_grad():
        for val_data in val_loader:
            try:
                val_images = val_data["image"].to(device, non_blocking=True)
                val_labels = val_data["label"].to(device, non_blocking=True)
                # No automatic mixed precision during validation
                val_outputs = model(val_images)
                val_preds = post_pred(val_outputs.cpu())
                val_truth = post_label(val_labels.cpu())
                dice_metric(y_pred=[val_preds], y=[val_truth])
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                dice_vals.append(metric)
                pbar_val.update(1)
                pbar_val.set_postfix({"dice": f"{metric:.4f}"})
            except RuntimeError as e:
                print(f"\n验证时发生错误: {str(e)}")
                continue
            except Exception as e:
                print(f"\n验证时发生未知错误: {str(e)}")
                continue
    pbar_val.close()
    avg_metric = np.mean(dice_vals) if dice_vals else 0.0
    val_dice_history.append(avg_metric)
    print(f"验证平均Dice: {avg_metric:.4f}")

    epoch_time = time.time() - start_time
    print(f"Epoch 耗时: {epoch_time:.2f}秒, 平均每批次: {epoch_time/len(train_loader):.2f}秒")

    if avg_metric > best_metric:
        best_metric = avg_metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(), f"unetr_checkpoints/best_model_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth")
        print(f"保存新的最佳模型! Epoch: {best_metric_epoch}, Dice: {best_metric:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            'dice': avg_metric
        }, f"unetr_checkpoints/checkpoint_epoch_{epoch+1}.pth")

    # Plot training curves (saved every 5 epochs)
    if (epoch + 1) % 5 == 0:
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.plot(train_loss_history, label='训练损失')
        plt.title('训练损失')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.subplot(1, 2, 2)
        plt.plot(val_dice_history, label='验证Dice', color='orange')
        plt.title('验证Dice')
        plt.xlabel('Epoch')
        plt.ylabel('Dice')
        plt.legend()
        plt.tight_layout()
        plt.savefig("logs/unetr_training_metrics.png")
        plt.close()

    # Free GPU memory
    torch.cuda.empty_cache()
    gc.collect()

print(f"\n训练完成! 最佳Dice: {best_metric:.4f} at epoch {best_metric_epoch}")

My code now looks like the above, but it still reports an error:

(covid_seg) (base) liulicheng@ailab-MS-7B79:~/MultiModal_MedSeg_2025$ /home/liulicheng/anaconda3/envs/covid_seg/bin/python /home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py
===== 环境信息 =====
Python版本: 3.8.12 | packaged by conda-forge | (default, Sep 29 2021, 19:52:28) [GCC 9.4.0]
PyTorch版本: 2.1.0+cu118
MONAI版本: 1.3.2
nibabel版本: 5.2.1
CUDA版本: 11.8
cuDNN版本: 8700
GPU: NVIDIA GeForce GTX 1080 Ti
输入尺寸: resized_size=(128, 128, 64), crop_size=(48, 48, 48)
Loading dataset: 100%|██████████| 104/104 [17:55<00:00, 10.34s/it]
Loading dataset: 100%|██████████| 26/26 [06:54<00:00, 15.94s/it]
模型参数总数: 8.99M
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.losses.dice DiceCELoss.__init__:ce_weight: Argument `ce_weight` has been deprecated since version 1.2. It will be removed in version 1.4. please use `weight` instead.
  warn_deprecated(argname, msg, warning_category)
===== GPU预热 =====
Traceback (most recent call last):
  File "/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py", line 237, in <module>
    model(dummy_input)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/networks/nets/unetr.py", line 207, in forward
    x, hidden_states_out = self.vit(x_in)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/networks/nets/vit.py", line 130, in forward
    x = self.patch_embedding(x)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/networks/blocks/patchembedding.py", line 142, in forward
    embeddings = x + self.position_embeddings
RuntimeError: The size of tensor a (27) must match the size of tensor b (256) at non-singleton dimension 1
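
The mismatch in the last line follows from the sizes defined in the script itself: UNETR is constructed with img_size=resized_size = (128, 128, 64), which with its 16 x 16 x 16 patches yields 8 x 8 x 4 = 256 position embeddings, while the GPU warm-up tensor (and the training crops) use crop_size = (48, 48, 48), which produces only 3 x 3 x 3 = 27 patches; adding the 27 patch embeddings to the 256 position embeddings fails with "tensor a (27) must match tensor b (256)". Below is a minimal sketch of one way to make the two consistent, assuming the crop size above is what the network will actually see; this is an illustrative change, not the only possible fix:

# Sketch: build UNETR with img_size equal to the spatial size actually fed to it.
# Reuses num_classes, crop_size and device from the script above.
model = UNETR(
    in_channels=1,
    out_channels=num_classes,
    img_size=crop_size,  # (48, 48, 48) -> 3*3*3 = 27 patches, matching the warm-up input
    feature_size=16,
    hidden_size=192,
    mlp_dim=768,
    num_heads=3,
    proj_type="perceptron",
    norm_name="batch",
    res_block=True,
    dropout_rate=0.1
).to(device)

dummy_input = torch.randn(1, 1, *crop_size).to(device)  # spatial size now matches img_size

Whichever size is chosen, img_size has to agree with the spatial size of every tensor passed through the network (warm-up, training crops and validation volumes alike); with the change above, the validation volumes at resized_size would also need to be cropped or resized to crop_size, or evaluated with a sliding-window style inference, before being fed to the model.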