[python] Why openpyxl raises the warning "UserWarning: Call to deprecated function"

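This article describes a newer way to access worksheets that replaces the deprecated workbook.get_sheet_by_name method. To get the first worksheet, workbook.worksheets[0] is recommended; it is more concise and follows current practice.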


The blog has moved to "捕获完成":

https://www.v2python.com

 

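 The warning appears because the method being called is marked as deprecated in openpyxl. It is currently recommended not to use this deprecated method: workbook.get_sheet_by_name('Sheet1')

   To get the first worksheet by position, use directly: sheet = workbook.worksheets[0]

 

    To look a worksheet up by name, index the workbook with the sheet name. That is:

    ws = wb["frequency"] 
    is equivalent to ws2 = wb.get_sheet_by_name('frequency')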
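
The equivalence above can be checked with a small in-memory workbook. This is only a minimal sketch: the second sheet name "other" is made up for illustration, and the call to the deprecated method is guarded with hasattr because the method may be absent in some openpyxl releases. The warning category is printed rather than assumed, since it can differ between openpyxl versions.

```python
import warnings

from openpyxl import Workbook

# Build a workbook in memory; no file on disk is needed.
wb = Workbook()
ws = wb.active              # the default sheet created by Workbook()
ws.title = "frequency"      # rename it to match the example in the text
wb.create_sheet("other")    # hypothetical second sheet, for illustration only

# Recommended access styles.
ws_by_name = wb["frequency"]     # lookup by sheet name
ws_by_index = wb.worksheets[0]   # lookup by position (first sheet)

print(ws_by_name is ws_by_index)  # both refer to the same worksheet object
print(wb.sheetnames)

# Deprecated access style: record whatever warning it emits.
caught = []
if hasattr(wb, "get_sheet_by_name"):
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure the warning is not filtered out
        ws_old = wb.get_sheet_by_name("frequency")
    print(ws_old is ws_by_name)
    for w in caught:
        # Print the category and message instead of assuming them.
        print(w.category.__name__, w.message)
```

Note that wb.worksheets[0] only matches wb["frequency"] when "frequency" happens to be the first sheet; if the sheet order can change, looking the sheet up by name is the safer choice.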
