(covid_seg) (base) liulicheng@ailab-MS-7B79:~/MultiModal_MedSeg_2025$ /home/liulicheng/anaconda3/envs/covid_seg/bin/python /home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py
===== Environment info =====
Python version: 3.8.12 | packaged by conda-forge | (default, Sep 29 2021, 19:52:28)
[GCC 9.4.0]
PyTorch version: 2.1.0+cu118
MONAI version: 1.3.2
nibabel version: 5.2.1
CUDA version: 11.8
cuDNN version: 8700
GPU: NVIDIA GeForce GTX 1080 Ti
Input size: resized_size=(128, 128, 64), crop_size=(48, 48, 48)
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 104/104 [18:16<00:00, 10.55s/it]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [06:43<00:00, 15.54s/it]
Total model parameters: 8.95M
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/utils/deprecate_utils.py:221: FutureWarning: monai.losses.dice DiceCELoss.__init__:ce_weight: Argument `ce_weight` has been deprecated since version 1.2. It will be removed in version 1.4. please use `weight` instead.
warn_deprecated(argname, msg, warning_category)
===== GPU warm-up =====
===== Start training =====
Epoch 1/200
Train Epoch 1: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:09<00:00, 5.42it/s, loss=0.7731]
Mean training loss: 0.7909
Current learning rate: 0.0000200
Val Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.99it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 11.96s, average per batch: 0.23s
New best model saved! Epoch: 1, Dice: 0.0288
Epoch 2/200
Train Epoch 2: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00, 7.92it/s, loss=0.7215]
Mean training loss: 0.7502
Current learning rate: 0.0000300
Val Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.71it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 9.00s, average per batch: 0.17s
Epoch 3/200
Train Epoch 3: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00, 7.96it/s, loss=0.6815]
Mean training loss: 0.7063
Current learning rate: 0.0000400
Val Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.48it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 9.02s, average per batch: 0.17s
Epoch 4/200
Train Epoch 4: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00, 7.59it/s, loss=0.6309]
Mean training loss: 0.6614
Current learning rate: 0.0000500
Val Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.85it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 9.25s, average per batch: 0.18s
Epoch 5/200
Train Epoch 5: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00, 7.93it/s, loss=0.6124]
Mean training loss: 0.6249
Current learning rate: 0.0000600
Val Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.62it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 9.01s, average per batch: 0.17s
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:386: UserWarning: Glyph 35757 (\N{CJK UNIFIED IDEOGRAPH-8BAD}) missing from current font.
plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:386: UserWarning: Glyph 32451 (\N{CJK UNIFIED IDEOGRAPH-7EC3}) missing from current font.
plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:386: UserWarning: Glyph 25439 (\N{CJK UNIFIED IDEOGRAPH-635F}) missing from current font.
plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:386: UserWarning: Glyph 22833 (\N{CJK UNIFIED IDEOGRAPH-5931}) missing from current font.
plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:386: UserWarning: Glyph 39564 (\N{CJK UNIFIED IDEOGRAPH-9A8C}) missing from current font.
plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:386: UserWarning: Glyph 35777 (\N{CJK UNIFIED IDEOGRAPH-8BC1}) missing from current font.
plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:387: UserWarning: Glyph 35757 (\N{CJK UNIFIED IDEOGRAPH-8BAD}) missing from current font.
plt.savefig("logs/unetr_training_metrics.png")
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:387: UserWarning: Glyph 32451 (\N{CJK UNIFIED IDEOGRAPH-7EC3}) missing from current font.
plt.savefig("logs/unetr_training_metrics.png")
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:387: UserWarning: Glyph 25439 (\N{CJK UNIFIED IDEOGRAPH-635F}) missing from current font.
plt.savefig("logs/unetr_training_metrics.png")
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:387: UserWarning: Glyph 22833 (\N{CJK UNIFIED IDEOGRAPH-5931}) missing from current font.
plt.savefig("logs/unetr_training_metrics.png")
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:387: UserWarning: Glyph 39564 (\N{CJK UNIFIED IDEOGRAPH-9A8C}) missing from current font.
plt.savefig("logs/unetr_training_metrics.png")
/home/liulicheng/MultiModal_MedSeg_2025/train/train_unetr.py:387: UserWarning: Glyph 35777 (\N{CJK UNIFIED IDEOGRAPH-8BC1}) missing from current font.
plt.savefig("logs/unetr_training_metrics.png")
Epoch 6/200
Train Epoch 6: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00, 7.90it/s, loss=0.5874]
Mean training loss: 0.5966
Current learning rate: 0.0000700
Val Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.82it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 8.99s, average per batch: 0.17s
Epoch 7/200
Train Epoch 7: 0%| | 0/52 [00:00<?, ?it/s]/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
/home/liulicheng/anaconda3/envs/covid_seg/lib/python3.8/site-packages/monai/transforms/utils.py:606: UserWarning: Num foregrounds 0, Num backgrounds 110592, unable to generate class balanced samples, setting `pos_ratio` to 0.
warnings.warn(
Train Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:06<00:00, 7.81it/s, loss=0.5655]
Mean training loss: 0.5751
Current learning rate: 0.0000800
Val Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 11.14it/s, dice=0.0469]
Mean validation Dice: 0.0288
Epoch time: 9.00s, average per batch: 0.17s
Here is the code:
# ======================
# UNETR training script
# ======================
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged,
    RandCropByPosNegLabeld, RandFlipd, EnsureTyped, Activations,
    AsDiscrete, ResizeWithPadOrCropd, RandZoomd, RandGaussianNoised
)
from monai.data import list_data_collate, CacheDataset  # use CacheDataset
from monai.networks.nets import UNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from glob import glob
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt
import gc
import nibabel as nib
import sys
import monai
import time
# Custom transform: turn the single-element lists returned by RandCropByPosNegLabeld back into tensors
class ExtractFirstSampledDict(monai.transforms.Transform):
    def __call__(self, data):
        out = {}
        for k, v in data.items():
            if isinstance(v, list) and len(v) == 1:
                out[k] = v[0]
            else:
                out[k] = v
        return out
# ======================
# Configuration (optimized version)
# ======================
root_dir = "datasets/LiTS/processed"
images_dir = os.path.join(root_dir, "images")
labels_dir = os.path.join(root_dir, "labels")
max_epochs = 200
batch_size = 2  # larger batch size
learning_rate = 1e-4
num_classes = 3
warmup_epochs = 10
use_amp = False  # AMP is off; set True to enable mixed-precision training ✅
accumulation_steps = 2  # gradient accumulation steps
# Disable MetaTensor to avoid decollate errors
os.environ["MONAI_USE_META_DICT"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Enable cuDNN benchmark mode
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
# Print environment info
print("===== Environment info =====")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"MONAI version: {monai.__version__}")
print(f"nibabel version: {nib.__version__}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"cuDNN version: {torch.backends.cudnn.version()}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
# Size settings: make sure every dimension is divisible by 16
def get_valid_size(size, divisor=16):
    return tuple([max(divisor, (s // divisor) * divisor) for s in size])
base_size = (128, 128, 64)
resized_size = get_valid_size(base_size)
crop_size = get_valid_size((48, 48, 48))  # smaller size to save GPU memory
print(f"Input size: resized_size={resized_size}, crop_size={crop_size}")
# ======================
# Data preprocessing (optimized version)
# ======================
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    ResizeWithPadOrCropd(  # pads or center-crops to crop_size (no interpolation) ✅
        keys=["image", "label"],
        spatial_size=crop_size,
        mode="constant"
    ),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=crop_size,
        pos=1.0,
        neg=1.0,
        num_samples=1,
        image_threshold=0
    ),
    ExtractFirstSampledDict(),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    # Removed the slow rotation augmentation ✅
    # RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.3, min_zoom=0.9, max_zoom=1.1, mode=("trilinear", "nearest")),  # lower probability
    RandGaussianNoised(keys=["image"], prob=0.1, mean=0.0, std=0.05),  # lower probability
    EnsureTyped(keys=["image", "label"], data_type="tensor"),
])
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    ResizeWithPadOrCropd(  # pads or center-crops to crop_size (no interpolation) ✅
        keys=["image", "label"],
        spatial_size=crop_size,
        mode="constant"
    ),
    EnsureTyped(keys=["image", "label"], data_type="tensor"),
])
images = sorted(glob(os.path.join(images_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(labels_dir, "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)
# Use CacheDataset to speed up data loading ✅
cache_dir = "./data_cache"
os.makedirs(cache_dir, exist_ok=True)
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,  # cache 100% of the data
    num_workers=0
)
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=0
)
# Optimized data loaders ✅
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,  # more worker processes
    prefetch_factor=2,  # prefetch batches
    collate_fn=list_data_collate,
    pin_memory=True  # use pinned memory
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    collate_fn=list_data_collate,
    pin_memory=True
)
# ======================
# Model construction (optimized version)
# ======================
model = UNETR(
    in_channels=1,
    out_channels=num_classes,
    img_size=crop_size,  # network input size; equals crop_size = (48, 48, 48), the same size as the validation inputs
    feature_size=16,
    hidden_size=192,
    mlp_dim=768,
    num_heads=3,
    proj_type="perceptron",  # ✅ replaces pos_embed
    norm_name="batch",
    res_block=True,
    dropout_rate=0.1
).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params / 1e6:.2f}M")
# ======================
# Loss + optimizer
# ======================
class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, ce_weight=class_weights, lambda_dice=0.5, lambda_ce=0.5)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * progress))
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
# ======================
# Evaluation
# ======================
post_pred = Compose([
    Activations(softmax=True),
    AsDiscrete(to_onehot=num_classes)
])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False, num_classes=num_classes)
scaler = GradScaler(enabled=use_amp)
# ======================
# Training loop (optimized version)
# ======================
best_metric = -1
best_metric_epoch = -1
train_loss_history = []
val_dice_history = []
os.makedirs("unetr_checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)
# GPU warm-up
print("===== GPU warm-up =====")
dummy_input = torch.randn(1, 1, *crop_size).to(device)
for _ in range(10):
    model(dummy_input)
torch.cuda.synchronize()
print("\n===== Start training =====")
for epoch in range(max_epochs):
    print(f"\nEpoch {epoch+1}/{max_epochs}")
    start_time = time.time()
    model.train()
    epoch_loss, step = 0, 0
    pbar_train = tqdm(total=len(train_loader), desc=f"Train Epoch {epoch+1}")
    optimizer.zero_grad()
    for batch_idx, batch_data in enumerate(train_loader):
        step += 1
        try:
            inputs = batch_data["image"].to(device, non_blocking=True)
            labels = batch_data["label"].to(device, non_blocking=True)
            # Automatic mixed precision (only active when use_amp=True)
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = loss_function(outputs, labels) / accumulation_steps
            # Gradient accumulation ✅
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            # Update the weights once every accumulation_steps batches
            if (batch_idx + 1) % accumulation_steps == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
            epoch_loss += loss.item() * accumulation_steps
            pbar_train.update(1)
            pbar_train.set_postfix({"loss": f"{loss.item() * accumulation_steps:.4f}"})
            # Free cached GPU memory every 10 batches
            if step % 10 == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if 'CUDA out of memory' in str(e):
                print("\nCUDA out of memory, skipping this batch")
                torch.cuda.empty_cache()
                gc.collect()
                optimizer.zero_grad()
            else:
                print(f"\nError during training: {str(e)}")
            continue
        except Exception as e:
            print(f"\nUnknown error during training: {str(e)}")
            continue
    # Apply any leftover accumulated gradients
    if (batch_idx + 1) % accumulation_steps != 0:
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
    pbar_train.close()
    epoch_loss /= len(train_loader)
    train_loss_history.append(epoch_loss)
    print(f"Mean training loss: {epoch_loss:.4f}")
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current learning rate: {current_lr:.7f}")
    # Validation
    model.eval()
    dice_vals = []
    pbar_val = tqdm(total=len(val_loader), desc=f"Val Epoch {epoch+1}")
    with torch.no_grad():
        for val_data in val_loader:
            try:
                val_images = val_data["image"].to(device, non_blocking=True)
                val_labels = val_data["label"].to(device, non_blocking=True)
                # No mixed precision during validation
                val_outputs = model(val_images)
                val_preds = post_pred(val_outputs.cpu())
                val_truth = post_label(val_labels.cpu())
                dice_metric(y_pred=[val_preds], y=[val_truth])
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                dice_vals.append(metric)
                pbar_val.update(1)
                pbar_val.set_postfix({"dice": f"{metric:.4f}"})
            except RuntimeError as e:
                print(f"\nError during validation: {str(e)}")
                continue
            except Exception as e:
                print(f"\nUnknown error during validation: {str(e)}")
                continue
    pbar_val.close()
    avg_metric = np.mean(dice_vals) if dice_vals else 0.0
    val_dice_history.append(avg_metric)
    print(f"Mean validation Dice: {avg_metric:.4f}")
    epoch_time = time.time() - start_time
    print(f"Epoch time: {epoch_time:.2f}s, average per batch: {epoch_time/len(train_loader):.2f}s")
    if avg_metric > best_metric:
        best_metric = avg_metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(), f"unetr_checkpoints/best_model_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth")
        print(f"New best model saved! Epoch: {best_metric_epoch}, Dice: {best_metric:.4f}")
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            'dice': avg_metric
        }, f"unetr_checkpoints/checkpoint_epoch_{epoch+1}.pth")
    # Plot training curves (saved every 5 epochs)
    if (epoch + 1) % 5 == 0:
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.plot(train_loss_history, label='Training loss')
        plt.title('Training loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.subplot(1, 2, 2)
        plt.plot(val_dice_history, label='Validation Dice', color='orange')
        plt.title('Validation Dice')
        plt.xlabel('Epoch')
        plt.ylabel('Dice')
        plt.legend()
        plt.tight_layout()
        plt.savefig("logs/unetr_training_metrics.png")
        plt.close()
    # Free GPU memory
    torch.cuda.empty_cache()
    gc.collect()
print(f"\nTraining finished! Best Dice: {best_metric:.4f} at epoch {best_metric_epoch}")
I feel something is wrong with the results this code produces.
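The mean validation Dice in the log is exactly 0.0288 in every epoch (and the last-batch value in the progress bar is always 0.0469), while the training loss falls from 0.7909 to 0.5751. That pattern suggests the metric does not depend on the model output at all. One likely cause, assuming MONAI's documented channel-first convention for array transforms: `post_pred` is applied to the batched tensor `val_outputs.cpu()` of shape `(1, 3, 48, 48, 48)`, and `Activations(softmax=True)` / `AsDiscrete` treat dim 0 as the channel axis, so the softmax runs over the batch axis of length 1 and returns 1.0 for every voxel. There is also no `argmax=True`, so the three class channels are never reduced to a label map. The usual MONAI pattern is to split the batch with `monai.data.decollate_batch` and post-process each item with `AsDiscrete(argmax=True, to_onehot=num_classes)`. The NumPy sketch below reproduces the axis mix-up on random stand-in data (the `logits` and `label` arrays are hypothetical, not LiTS data) and then computes per-class Dice after an argmax over the channel axis.

```python
import numpy as np


def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def one_hot(labels, num_classes):
    """Integer label map (D, H, W) -> one-hot array (C, D, H, W)."""
    classes = np.arange(num_classes).reshape(-1, 1, 1, 1)
    return (classes == labels[None]).astype(np.float32)


def dice_per_class(pred_oh, true_oh, eps=1e-8):
    """Dice score for each channel of two one-hot arrays shaped (C, D, H, W)."""
    axes = (1, 2, 3)
    inter = (pred_oh * true_oh).sum(axis=axes)
    denom = pred_oh.sum(axis=axes) + true_oh.sum(axis=axes)
    return 2.0 * inter / (denom + eps)


rng = np.random.default_rng(0)
num_classes = 3
# Hypothetical stand-ins: a batch of one network output and one label map
logits = rng.normal(size=(1, num_classes, 8, 8, 8))
label = rng.integers(0, num_classes, size=(8, 8, 8))

# Mix-up: softmax over axis 0, which is the batch axis (length 1), not the classes
probs_wrong = softmax(logits, axis=0)
print(np.allclose(probs_wrong, 1.0))  # True for any logits: the model output is lost

# Intended: drop the batch axis, then softmax and argmax over the class channel
probs = softmax(logits[0], axis=0)   # (C, D, H, W)
pred = probs.argmax(axis=0)          # (D, H, W) class indices
dice = dice_per_class(one_hot(pred, num_classes), one_hot(label, num_classes))
print(dice.round(3))                 # one Dice value per class
```

Because the mistaken softmax always yields 1.0, every later step sees the same constant prediction regardless of the weights, which would explain why the Dice never moves.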
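The repeated warning `Num foregrounds 0, Num backgrounds 110592` matches the crop size exactly: 48 × 48 × 48 = 110592, so every voxel of the cube handed to `RandCropByPosNegLabeld` is labeled background. In `train_transforms`, `ResizeWithPadOrCropd(spatial_size=crop_size)` runs first and pads or center-crops each resampled volume to 48³; the positive/negative sampler therefore only ever sees that central cube and, with the same `spatial_size`, can only return it unchanged. `val_transforms` applies the same center crop, so validation also looks only at the middle of each volume (`resized_size` is computed but never used). A common alternative, offered as a suggestion rather than a verified fix, is to drop the center crop from the training pipeline (or use `SpatialPadd` only to guarantee a minimum size), let `RandCropByPosNegLabeld` choose patches from the full volume, and evaluate whole volumes with `monai.inferers.sliding_window_inference`. The NumPy sketch below uses a hypothetical 160 × 160 × 80 label volume with an off-center foreground to compare a center crop with a crop centered on a random foreground voxel; the latter only illustrates the idea behind positive sampling and is not MONAI's implementation.

```python
import numpy as np


def center_crop(vol, size):
    """Center crop, mimicking the cropping branch of a pad-or-crop transform."""
    starts = [(s - k) // 2 for s, k in zip(vol.shape, size)]
    return vol[tuple(slice(st, st + k) for st, k in zip(starts, size))]


def foreground_centered_crop(vol, size, rng):
    """Crop a window around a random foreground voxel, clamped to the volume."""
    fg = np.argwhere(vol > 0)                    # (N, 3) foreground coordinates
    center = fg[rng.integers(len(fg))]
    starts = [int(np.clip(c - k // 2, 0, s - k)) for c, k, s in zip(center, size, vol.shape)]
    return vol[tuple(slice(st, st + k) for st, k in zip(starts, size))]


crop_size = (48, 48, 48)
# Hypothetical resampled label volume: 0 = background, 1 = liver, 2 = tumor
label = np.zeros((160, 160, 80), dtype=np.uint8)
label[10:50, 20:70, 5:40] = 1    # "liver" placed away from the volume center
label[25:35, 30:40, 15:25] = 2   # "tumor" inside it

rng = np.random.default_rng(0)
centered = center_crop(label, crop_size)
sampled = foreground_centered_crop(label, crop_size, rng)
print(centered.size, int((centered > 0).sum()))  # 110592 0, the counts in the warning
print(int((sampled > 0).sum()) > 0)              # True
```

The clamped window always contains the chosen foreground voxel, so every sampled patch has at least some liver or tumor, whereas the fixed center crop can miss the organ entirely.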