nnUnet(代码)-训练部分

这篇博客深入解析了nnUNet训练过程,包括训练计划的获取、数据增强参数初始化、五折交叉验证、数据加载、网络与优化器初始化。nnUNetTrainer版本一采用DC_and_CE_loss作为损失函数,使用Adam优化器和基于损失平均值的学习率调度。版本二引入了深监督,通过权重调整强化损失函数,并改用SGD优化器及自定义学习率下降策略。此外,还讨论了数据增强参数的变化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

学习目标:逐步分析nnunet训练部分

学习内容:training部分

· 拿到训练plans(计划)
· 初始化数据增强参数
· 采用五折交叉验证
· dataset与dataloader/数据加载过程
· 初始化网络
· 初始化优化器与学习率函数

1.nnUNetTrainer(版本一的训练方法)

··· 损失函数:

self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

··· 优化器与学习率函数:
优化器用adam
学习率的调整是用的损失函数的加权平均值来判断是否变动的方法

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                          amsgrad=True)
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
                                                           patience=self.lr_scheduler_patience,
                                                           verbose=True, threshold=1e-3,
                                                           threshold_mode="abs")
# 学习率函数设置
self.train_loss_MA_alpha = 0.93  # alpha * old + (1-alpha) * new

    def update_train_loss_MA(self):
        if self.train_loss_MA is None:
            self.train_loss_MA = self.all_tr_losses[-1]
        else:
            self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
                                 self.all_tr_losses[-1]
# lr scheduler is updated with moving average val loss. should be more robust
self.lr_scheduler.step(self.train_loss_MA)

2.nnUNetTrainerV2(版本二的训练方法)

··· 加强了损失函数(深监督):
还是原来损失,但是添加了一个策略:给每层的损失加一个权重,分辨率越高的权重越大,简单说就是针对中间隐藏层特征透明度不高以及深层网络中浅层以及中间网络难以训练的问题。

################# 封装损失函数进入深度学习(深监督) ############
        # 需要知道网络深度
        # net_numpool = len(self.plans['pool_op_kernel_sizes'])

        # 我们给每个输出一个权重,该权重随着分辨率的降低呈指数递减(除以2)
        # 这使得更高的分辨率输出在损失中有更大的权重
        weights = np.array([1 / (2 ** i) for i in range(self.net_numpool)])

        # 我们不使用最低的2个输出。标准化权重,使其总和为1
        mask = np.array([True] + [True if i < self.net_numpool - 1 else False for i in range(1, self.net_numpool)])
        weights[~mask] = 0
        weights = weights / weights.sum()
        self.ds_loss_weights = weights

        # 封装损失函数
        self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)

··· 重写了优化器与学习率函数
采用SGD与自定义的学习率下降函数

    def initialize_optimizer_and_scheduler(self):
        assert self.network is not None, "self.initialize_network must be called first"
        self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
                                         momentum=0.99, nesterov=True)
        self.lr_scheduler = None
 def maybe_update_lr(self, epoch=None):

        if epoch is None:
            ep = self.epoch + 1
        else:
            ep = epoch
        self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
        
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent

··· 重写了数据增强参数

3.后面还有DP等等三四个版本,是基于版本二改变的,主要是通过混合精度进行训练增加训练速度

<think>好的,我需要帮助用户在nnUNet框架中关闭数据增强中的翻转功能。首先,我得回忆一下nnUNet的数据增强机制。nnUNet是一个用于医学图像分割的框架,其数据增强通常在训练时应用,包括旋转、缩放、翻转等操作。用户想要禁用翻转功能,可能需要修改相关的配置文件或代码。 根据用户提供的引用,其他框架如React、Android、Spring Boot都有不同的禁用方法,但nnUNet可能有自己的配置方式。我需要查找nnUNet的文档或源码,看看数据增强的具体实现。通常在深度学习框架中,数据增强的参数可能在训练脚本或配置文件中设置。 nnUNet的数据增强配置可能在`nnunet/training/data_augmentation`目录下的某个文件中。例如,默认的数据增强参数可能在`default_data_augmentation_params`里。翻转功能可能由`do_mirroring`或`mirror_axes`参数控制。用户需要将`do_mirroring`设为False,或者调整`mirror_axes`为空列表。 此外,用户可能需要修改计划文件(如plans文件),这些文件通常存储在`nnUNet_preprocessed`目录下的任务文件夹中。修改这些文件中的相应参数可以禁用翻转。需要注意的是,直接修改代码可能会影响其他功能,所以最好通过配置文件调整。 另外,用户可能需要重新运行预处理步骤,以确保更改生效。因为nnUNet在预处理阶段可能会根据配置生成增强策略,所以修改后需要重新处理数据。 总结步骤:找到数据增强参数配置,关闭翻转相关设置,调整计划文件,重新预处理数据。需要确保用户正确修改了所有相关的地方,避免遗漏导致功能未禁用。</think>在nnUNet框架中禁用数据翻转功能需要通过修改训练配置文件或数据增强参数实现。以下是具体操作步骤: ### 方法一:修改数据增强参数 1. 在训练配置文件中定位`data_augmentation`参数组 2. 将镜像翻转参数设置为False: ```python # 在nnunet/training/network_training/nnUNetTrainer.py中修改 self.data_aug_params["do_mirroring"] = False # 关闭所有轴向翻转 self.data_aug_params["mirror_axes"] = () # 空元组表示不进行任何轴向翻转 ``` ### 方法二:调整计划文件 1. 打开对应任务的plans文件(位于`nnUNet_preprocessed/TASKXXX`目录) 2. 修改数据增强配置部分: ```json "data_augmentation": { "mirror": {"prob_per_axis": 0.0}, # 将各轴向翻转概率设为0 "rotation": {...} # 保留其他增强配置 } ``` ### 方法三:自定义训练器(推荐) ```python from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer class CustomTrainer(nnUNetTrainer): def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None): super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage) # 禁用所有翻转增强 self.data_aug_params["do_mirroring"] = False self.data_aug_params["mirror_axes"] = () # 训练时指定自定义训练nnUNet_train CONFIGURATION -tr CustomTrainer ... ``` 修改后需要重新运行数据预处理: ```bash nnUNet_plan_and_preprocess -t TASK_ID --verify_dataset_integrity ``` 需要注意: 1. 禁用数据增强可能降低模型泛化能力,建议保留其他增强手段 2. 不同版本nnUNet的配置文件路径可能有所变化,建议查阅对应版本的文档[^4] 3. 如果使用预训练模型,需要从头开始训练才能应用新的增强策略
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值