nnU-Net源码剖写解读

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

nnU-Net已经称为了医学图像分割的backbone是现有医学分割模型的集大成者,它已其简单的训练流程和简单U-net模型就可达到sota的水平。我们将对nnUnetV2版本的源码进行剖析,理解和分析这个框架的原理,并希望在学习的后期能够修改源码作出创新。


nnU-Net分析结构
我们知道nnU-NetV2的操作主要分为3部分

  1. 预处理,将图像按照自适应方法处理至统一标准
  2. 训练,对网络进行训练
  3. 推理,通过训练权重预测标签

nnUNetV2的结构

plan_and_preprocess_entry()
	P1. extract_fingerprints
	P2. plan_experiments
	P3. preprocess
		cropping
		normalization
		preprocessors
		resampling
training
	P4. dataloader
		base
		2d
		3d
		nnunet_dataset
	P5. data_augmentation
		compute_initial_patch_size
		custom_transforms
nnUNetTrainer
	P6. initialize
			including network architectures
	P7. train one epoch
	P8. validation
	P9.loss
inference
	P10. predict_from_raw_data
	P11. slideing window
	P12. export prediction
	P13. find best configuration
	P14. postprocessing
dynamic network
	P15. building blocks + resnet
	P16. building blocks + vgg

二、分布剖析

plan_and_preprocess_entry()

请添加图片描述

1. extract_fingerprints_entry

extract_fingerprints的目的是获取数据集的“指纹”即数据集的统计学参数。
打开运行预处理的主程序代码Plan_and_preprocess_entrypoints.py,其中extract_fingerprint_entry函数通过argparse的方法传递超参数(键值传递)给Plan_and_preprocess_ap.py中的recursive_find_python_class函数动态递归查找指定的数据集dataname运行
请添加图片描述

,再执行当前文件下extract_fingerprint_dataset函数从【nnUNet_raw, dataset_name】文件路径获取数据。后通过调用fingerprint_extractor_class所指向的fingerprint_extractor.py中DatasetFingerprintExtractor函数从json文件中获取该数据集的fingerprint。

其中DatasetFingerprintExtractor函数包括

  1. collect_foreground_intensities静态方法,获取分割mask和图像数据,以及随机种子、采样个数、图像前景像素强度、均值等统计强度信息。
  2. Analyse_case静态方法接受文件列表、分割文件路径、图像读取器、对图像执行裁剪,run方法设置输出路径使用ptqdm运行analyse_case方法获取每个训练样本裁剪后的形状、间距、前景强度和统计性息并计算裁剪后图像各种信息的中位数将其保存在json文件中。

    这里记录了当前图片的空间分辨率、图像大小和其平均水平

请添加图片描述
这里是finger记录的json文件,它分析了样本统计学参数

2. plan_experiments

plan_experiments的目的是确定其模型的配置情况
程序调用规则:
plan_and_preprocess_api.py中包含plan_experiment_dataset和plan_experiments两个函数,plan_experiment_dataset用于配置,plan_experiments用于执行
调用nnunetv2.experiment_planning.experiment_planners.default_experiment_planner.ExperimentPlanner下的类ExperimentPlanner。

其中ExperimentPlanner类包括

  1. Determine_reader_writer方法从数据集中获取训练图像标识符,用于确定阅读器和编写器。
  2. Determine_reader_writer方法从数据集中获取训练图像标识符,用于确定阅读器和编写器。
  3. static_estimate_VRAM_usage方法determine_resampling方法用于确认数据和标签重采样方法。
  4. determine_segmentation_softmax_export_fn方法用于确定分割概率的导出函数。
  5. determine_fullres_target_spacing方法用于确定全分辨率(fullres)的目标间距。
  6. determine_normalization_scheme_and_whether_mask_is_used_for_norm方法用于确定归一化方案以及是否使用掩码进行归一化。
  7. determine_transpose方法用于确定数据是否需要转置。get_plans_for_configuration方法用于为给定配置生成计划,包括批处理大小、补丁大小、归一化方案等。
  8. plan_experiment用于规划实验,包括2D和3D配置的计划。
  9. save_plans方法用于将计划保存到文件中。
  10. generate_data_identifier方法用于区分来自不同计划的相同配置的数据。
  11. Load_plans方法用于从文件中加载计划。
    plan_experiments,通过递归寻找数据集路径,并执行上述plan_experiment_datase

给出配置情况,给出了
data_identifier:nnU-Net模型方法
preprocessor_name: 默认的预处理器
batchsize: 批大小
patchsize: 图像块大小
median_image_size_in_voxels:医学图像平均大小
spacing:空间分辨率
normalization_schemes:归一化策略
use_mask_for_norm:是否对mask归一化
UNet_class_name:使用的Unet类型
UNet_base_num_features:初始特征层维度
n_conv_per_stage_encoder,n_conv_per_stage_decoder:上、下采样
pool_op_kernel_sizes,conv_kernel_sizes:池化核与卷积核大小
unet_max_num_features:最大特征维度
esampling_fn_data、resampling_fn_probabilities、resampling_fn_seg:数据、概率、分割的重采样函数
batch_dice:是否使用批次Dice损失
next_stage_names、previous_stage_name:下一阶段和上一阶段名称

2D U-Net configuration
{‘data_identifier’: ‘nnUNetPlans_2d’, ‘preprocessor_name’: ‘DefaultPreprocessor’, ‘batch_size’: 16, ‘patch_size’: array([384, 512]), ‘median_image_size_in_voxels’: array([384., 512.]), ‘spacing’: array([0.703125, 0.703125]), ‘normalization_schemes’: [‘ZScoreNormalization’], ‘use_mask_for_norm’: [False], ‘UNet_class_name’: ‘PlainConvUNet’, ‘UNet_base_num_features’: 32, ‘n_conv_per_stage_encoder’: (2, 2, 2, 2, 2, 2, 2), ‘n_conv_per_stage_decoder’: (2, 2, 2, 2, 2, 2), ‘num_pool_per_axis’: [6, 6], ‘pool_op_kernel_sizes’: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]], ‘conv_kernel_sizes’: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], ‘unet_max_num_features’: 512, ‘resampling_fn_data’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_seg’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_data_kwargs’: {‘is_seg’: False, ‘order’: 3, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_seg_kwargs’: {‘is_seg’: True, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_probabilities’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_probabilities_kwargs’: {‘is_seg’: False, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘batch_dice’: True}

3D lowres U-Net configuration:
{‘data_identifier’: ‘nnUNetPlans_3d_lowres’, ‘preprocessor_name’: ‘DefaultPreprocessor’, ‘batch_size’: 2, ‘patch_size’: array([ 64, 160, 224]), ‘median_image_size_in_voxels’: [98, 261, 349], ‘spacing’: array([1.1748269 , 1.03256277, 1.03256277]), ‘normalization_schemes’: [‘ZScoreNormalization’], ‘use_mask_for_norm’: [False], ‘UNet_class_name’: ‘PlainConvUNet’, ‘UNet_base_num_features’: 32, ‘n_conv_per_stage_encoder’: (2, 2, 2, 2, 2, 2), ‘n_conv_per_stage_decoder’: (2, 2, 2, 2, 2), ‘num_pool_per_axis’: [4, 5, 5], ‘pool_op_kernel_sizes’: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], ‘conv_kernel_sizes’: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], ‘unet_max_num_features’: 320, ‘resampling_fn_data’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_seg’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_data_kwargs’: {‘is_seg’: False, ‘order’: 3, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_seg_kwargs’: {‘is_seg’: True, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_probabilities’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_probabilities_kwargs’: {‘is_seg’: False, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘batch_dice’: False, ‘next_stage’: ‘3d_cascade_fullres’}

3D fullres U-Net configuration:
{‘data_identifier’: ‘nnUNetPlans_3d_fullres’, ‘preprocessor_name’: ‘DefaultPreprocessor’, ‘batch_size’: 2, ‘patch_size’: array([ 64, 160, 224]), ‘median_image_size_in_voxels’: array([144., 384., 512.]), ‘spacing’: array([0.79999995, 0.703125 , 0.703125 ]), ‘normalization_schemes’: [‘ZScoreNormalization’], ‘use_mask_for_norm’: [False], ‘UNet_class_name’: ‘PlainConvUNet’, ‘UNet_base_num_features’: 32, ‘n_conv_per_stage_encoder’: (2, 2, 2, 2, 2, 2), ‘n_conv_per_stage_decoder’: (2, 2, 2, 2, 2), ‘num_pool_per_axis’: [4, 5, 5], ‘pool_op_kernel_sizes’: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], ‘conv_kernel_sizes’: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], ‘unet_max_num_features’: 320, ‘resampling_fn_data’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_seg’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_data_kwargs’: {‘is_seg’: False, ‘order’: 3, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_seg_kwargs’: {‘is_seg’: True, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_probabilities’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_probabilities_kwargs’: {‘is_seg’: False, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘batch_dice’: True}

3. Preprocess

nnunetv2.utilities.plans_handling.plans_handler.ConfigurationManager下preprocessor函数实现了对上述配置的批量化预处理包括转置、裁剪、重采样

裁剪:
图像裁剪就是将三维的医学图像裁剪到它的非零区域,具体方法就是在图像中寻找一个最小的三维bounding box,该bounding box区域以外的值为0,使用这个bounding box对图像进行裁剪。
请添加图片描述
重采样:
使用数据集各个图像不同spacing的中值,根据target_spacing确定每张图像的目标尺寸。每张图像, spacing和shape之间的乘积为一个定值,代表整个图像在实际空间中的大小。调用skimage库中的reisze函数对每张图像进行resize即可

4. Dataloader

dataloader分为2d和3d两个py文件
2D和3D的DataLoader都继承自batchgenerators.dataloading.data_loader中Dataloader

nnUNetDataLoader2D首先获取finger中指定的slice切片数,生成一个训练数据批次。它选择从 2D 图像中获取特定切片以形成批次。
nnUNetDataLoader3D,生成一个训练数据批次。它选择从 3D 图像中获取特定切片以形成批次

注意!!!
常规的一个epoch的定义就是把dataloader里所有的batch跑完一遍,一个batch中包含n个样本,通常认知里每个样本分别来自一幅图像,那么也就是说一个epoch就是把数据集里的所有图像跑完一遍,但nnunet不同,它的样本可能来自同一幅图像,也就是说同样的batch数,nnunet不一定能完整的跑完整个数据集里的每一幅图像,对于“完整”这个概念,作者在github上是这么解释的:由于在3D图像中一定会切割patch,那么首先无论如何都做不到把原来图像中的每个patch都裁出来进行训练,换句话说首先对于单个图像就没有“完整”训练的概念,那么强调在一个epoch中“完整”遍历每一张图像也就没有意义,所以nnunet的一个epoch不是跑完整个数据集结束,而是跑完指定的batch数结束。

5. data_argument

nnUnet通过batchgenerators库中的数据增强策略实现在线数据增强,如果需要修改数据增强的策略可以在nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer.get_training_transforms这个函数中对数据增强的方法增添和删减。
作者在原始代码中使用的强度数据增强策略为:高斯噪声、模糊、亮度变换、反转图像、对比度增强、随机降采样、连续的两次gamma变换等等,另外还包含空间变换的方法镜像变换、随机旋转、缩放、平移,维度变换的方法针对二进制操作符的随机应用、随机移除一部分与分割标签连接、标签变换的方法对数据和标签的处理、使用掩码进行数据归一化
若想要修改增强策略,具体可根据batchgenerators库数据增强方法,将其添加到nnUnet的get_training_transforms中列表中

training

6. initialize

nnU-Net的初始化用于构建网络结构、配置优化器策略及参数、设置损失函数上述构建的方式均通过先前保存的json文件中获取配置情况

nnU-Net的基础配置的网络结构共有4种,2d,3d,3d_low,3d_cas,其中级联网络需要先在3d_low上训练。

优化器使用SGD,默认学习率衰减策略为线性衰减

损失函数,如果是基于区域的分割则使用Dice和BCE损失结合的方式,否则使用Dice和CE损失,其中Dice使用MemoryEfficientSoftDiceLoss

7-8. train one epoch and validation

同常规神经网络,nnU-Net会把主要代码包装在函数中,通过self调用的方式覆盖

9.loss

inference

10.predict_from_raw_data

nnU-Net v2 是 nnU-Net 框架的升级版本,它在医学图像分割任务中表现出了更强的适应性和更高的精度。该框架通过自动化的模型配置和训练流程优化了传统 U-Net 的性能[^1]。 ### 网络架构详解 nnU-Net v2 的核心改进在于其模块化的设计以及对不同数据集特性的自适应能力。与原始 U-Net 相比,nnU-Net v2 引入了多种新的组件和技术来提升分割效果,包括但不限于动态网络架构选择、残差连接、注意力机制等。 #### 编码器-解码器结构 nnU-Net v2 保持了经典的编码器-解码器结构,其中编码器负责提取特征,而解码器则尝试恢复空间信息以生成最终的分割图。这种设计允许模型学习从输入到输出的复杂映射关系[^2]。 #### 动态网络架构 一个关键的变化是使用了动态网络架构,例如 CAPlainConvUNet 或者其他变种,这些架构可以根据具体的任务需求进行调整,从而达到最佳性能。这通常涉及到卷积层的数量、大小以及其他超参数的选择[^1]。 #### 残差连接 为了缓解深度神经网络中的梯度消失问题,并促进信息流动,nnU-Net v2 在其编码器部分加入了残差块(Residual Blocks)。这样的设计有助于训练更深层次的模型并提高准确性[^2]。 #### 注意力机制 某些版本的 nnU-Net v2 可能还会集成注意力机制,比如SE Block (Squeeze-and-Excitation block),用以增强重要特征的同时抑制不相关的背景噪声,进一步提高了分割质量。 ### 原理图 虽然无法直接提供原理图,但可以描述典型的 nnU-Net v2 架构布局如下: ``` Input Layer | [Convolutional Layers with Residual Connections] | [Bottleneck Layer] | [Transposed Convolutional Layers for Upsampling] | Output Segmentation Map ``` ### 模型结构示例代码 以下是一个简化的 Python 示例,展示如何构建一个基于 PyTorch 的 nnU-Net v2 风格的基本模型: ```python import torch from torch import nn class SimpleResBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.norm1 = nn.InstanceNorm3d(in_channels) self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.norm2 = nn.InstanceNorm3d(in_channels) def forward(self, x): residual = x x = F.relu(self.norm1(self.conv1(x))) x = self.norm2(self.conv2(x)) x += residual return F.relu(x) # 创建一个简单的nnU-Netv2风格的模型实例 model = nn.Sequential( # Encoder SimpleResBlock(1), # 输入通道数为1,假设处理的是灰度图像 # ... 添加更多层 ... # Bottleneck nn.Conv3d(64, 128, kernel_size=3, padding=1), # Decoder # ... 反卷积层和其他上采样操作 ... # Output layer nn.Conv3d(64, num_classes, kernel_size=1) # num_classes代表不同的分割类别数目 ) ``` 请注意,上述代码仅作为一个基础示例,并未包含所有可能的功能或优化选项。实际应用中,您需要根据具体需求定制网络细节,并参考官方文档或者相关论文获取最新的实现方法和技术细节。
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

supernova121

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值