提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
nnU-Net源码剖写解读
前言
nnU-Net已经称为了医学图像分割的backbone是现有医学分割模型的集大成者,它已其简单的训练流程和简单U-net模型就可达到sota的水平。我们将对nnUnetV2版本的源码进行剖析,理解和分析这个框架的原理,并希望在学习的后期能够修改源码作出创新。
nnU-Net分析结构
我们知道nnU-NetV2的操作主要分为3部分
- 预处理,将图像按照自适应方法处理至统一标准
- 训练,对网络进行训练
- 推理,通过训练权重预测标签
nnUNetV2的结构
plan_and_preprocess_entry()
P1. extract_fingerprints
P2. plan_experiments
P3. preprocess
cropping
normalization
preprocessors
resampling
training
P4. dataloader
base
2d
3d
nnunet_dataset
P5. data_augmentation
compute_initial_patch_size
custom_transforms
nnUNetTrainer
P6. initialize
including network architectures
P7. train one epoch
P8. validation
P9.loss
inference
P10. predict_from_raw_data
P11. slideing window
P12. export prediction
P13. find best configuration
P14. postprocessing
dynamic network
P15. building blocks + resnet
P16. building blocks + vgg
二、分布剖析
plan_and_preprocess_entry()

1. extract_fingerprints_entry
extract_fingerprints的目的是获取数据集的“指纹”即数据集的统计学参数。
打开运行预处理的主程序代码Plan_and_preprocess_entrypoints.py,其中extract_fingerprint_entry函数通过argparse的方法传递超参数(键值传递)给Plan_and_preprocess_ap.py中的recursive_find_python_class函数动态递归查找指定的数据集dataname运行

,再执行当前文件下extract_fingerprint_dataset函数从【nnUNet_raw, dataset_name】文件路径获取数据。后通过调用fingerprint_extractor_class所指向的fingerprint_extractor.py中DatasetFingerprintExtractor函数从json文件中获取该数据集的fingerprint。
其中DatasetFingerprintExtractor函数包括
- collect_foreground_intensities静态方法,获取分割mask和图像数据,以及随机种子、采样个数、图像前景像素强度、均值等统计强度信息。
- Analyse_case静态方法接受文件列表、分割文件路径、图像读取器、对图像执行裁剪,run方法设置输出路径使用ptqdm运行analyse_case方法获取每个训练样本裁剪后的形状、间距、前景强度和统计性息并计算裁剪后图像各种信息的中位数将其保存在json文件中。

这里记录了当前图片的空间分辨率、图像大小和其平均水平

这里是finger记录的json文件,它分析了样本统计学参数
2. plan_experiments
plan_experiments的目的是确定其模型的配置情况
程序调用规则:
plan_and_preprocess_api.py中包含plan_experiment_dataset和plan_experiments两个函数,plan_experiment_dataset用于配置,plan_experiments用于执行
调用nnunetv2.experiment_planning.experiment_planners.default_experiment_planner.ExperimentPlanner下的类ExperimentPlanner。
其中ExperimentPlanner类包括
- Determine_reader_writer方法从数据集中获取训练图像标识符,用于确定阅读器和编写器。
- Determine_reader_writer方法从数据集中获取训练图像标识符,用于确定阅读器和编写器。
- static_estimate_VRAM_usage方法determine_resampling方法用于确认数据和标签重采样方法。
- determine_segmentation_softmax_export_fn方法用于确定分割概率的导出函数。
- determine_fullres_target_spacing方法用于确定全分辨率(fullres)的目标间距。
- determine_normalization_scheme_and_whether_mask_is_used_for_norm方法用于确定归一化方案以及是否使用掩码进行归一化。
- determine_transpose方法用于确定数据是否需要转置。get_plans_for_configuration方法用于为给定配置生成计划,包括批处理大小、补丁大小、归一化方案等。
- plan_experiment用于规划实验,包括2D和3D配置的计划。
- save_plans方法用于将计划保存到文件中。
- generate_data_identifier方法用于区分来自不同计划的相同配置的数据。
- Load_plans方法用于从文件中加载计划。
plan_experiments,通过递归寻找数据集路径,并执行上述plan_experiment_datase
给出配置情况,给出了
data_identifier:nnU-Net模型方法
preprocessor_name: 默认的预处理器
batchsize: 批大小
patchsize: 图像块大小
median_image_size_in_voxels:医学图像平均大小
spacing:空间分辨率
normalization_schemes:归一化策略
use_mask_for_norm:是否对mask归一化
UNet_class_name:使用的Unet类型
UNet_base_num_features:初始特征层维度
n_conv_per_stage_encoder,n_conv_per_stage_decoder:上、下采样
pool_op_kernel_sizes,conv_kernel_sizes:池化核与卷积核大小
unet_max_num_features:最大特征维度
esampling_fn_data、resampling_fn_probabilities、resampling_fn_seg:数据、概率、分割的重采样函数
batch_dice:是否使用批次Dice损失
next_stage_names、previous_stage_name:下一阶段和上一阶段名称
2D U-Net configuration
{‘data_identifier’: ‘nnUNetPlans_2d’, ‘preprocessor_name’: ‘DefaultPreprocessor’, ‘batch_size’: 16, ‘patch_size’: array([384, 512]), ‘median_image_size_in_voxels’: array([384., 512.]), ‘spacing’: array([0.703125, 0.703125]), ‘normalization_schemes’: [‘ZScoreNormalization’], ‘use_mask_for_norm’: [False], ‘UNet_class_name’: ‘PlainConvUNet’, ‘UNet_base_num_features’: 32, ‘n_conv_per_stage_encoder’: (2, 2, 2, 2, 2, 2, 2), ‘n_conv_per_stage_decoder’: (2, 2, 2, 2, 2, 2), ‘num_pool_per_axis’: [6, 6], ‘pool_op_kernel_sizes’: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]], ‘conv_kernel_sizes’: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], ‘unet_max_num_features’: 512, ‘resampling_fn_data’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_seg’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_data_kwargs’: {‘is_seg’: False, ‘order’: 3, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_seg_kwargs’: {‘is_seg’: True, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_probabilities’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_probabilities_kwargs’: {‘is_seg’: False, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘batch_dice’: True}
3D lowres U-Net configuration:
{‘data_identifier’: ‘nnUNetPlans_3d_lowres’, ‘preprocessor_name’: ‘DefaultPreprocessor’, ‘batch_size’: 2, ‘patch_size’: array([ 64, 160, 224]), ‘median_image_size_in_voxels’: [98, 261, 349], ‘spacing’: array([1.1748269 , 1.03256277, 1.03256277]), ‘normalization_schemes’: [‘ZScoreNormalization’], ‘use_mask_for_norm’: [False], ‘UNet_class_name’: ‘PlainConvUNet’, ‘UNet_base_num_features’: 32, ‘n_conv_per_stage_encoder’: (2, 2, 2, 2, 2, 2), ‘n_conv_per_stage_decoder’: (2, 2, 2, 2, 2), ‘num_pool_per_axis’: [4, 5, 5], ‘pool_op_kernel_sizes’: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], ‘conv_kernel_sizes’: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], ‘unet_max_num_features’: 320, ‘resampling_fn_data’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_seg’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_data_kwargs’: {‘is_seg’: False, ‘order’: 3, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_seg_kwargs’: {‘is_seg’: True, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_probabilities’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_probabilities_kwargs’: {‘is_seg’: False, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘batch_dice’: False, ‘next_stage’: ‘3d_cascade_fullres’}
3D fullres U-Net configuration:
{‘data_identifier’: ‘nnUNetPlans_3d_fullres’, ‘preprocessor_name’: ‘DefaultPreprocessor’, ‘batch_size’: 2, ‘patch_size’: array([ 64, 160, 224]), ‘median_image_size_in_voxels’: array([144., 384., 512.]), ‘spacing’: array([0.79999995, 0.703125 , 0.703125 ]), ‘normalization_schemes’: [‘ZScoreNormalization’], ‘use_mask_for_norm’: [False], ‘UNet_class_name’: ‘PlainConvUNet’, ‘UNet_base_num_features’: 32, ‘n_conv_per_stage_encoder’: (2, 2, 2, 2, 2, 2), ‘n_conv_per_stage_decoder’: (2, 2, 2, 2, 2), ‘num_pool_per_axis’: [4, 5, 5], ‘pool_op_kernel_sizes’: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], ‘conv_kernel_sizes’: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], ‘unet_max_num_features’: 320, ‘resampling_fn_data’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_seg’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_data_kwargs’: {‘is_seg’: False, ‘order’: 3, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_seg_kwargs’: {‘is_seg’: True, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘resampling_fn_probabilities’: ‘resample_data_or_seg_to_shape’, ‘resampling_fn_probabilities_kwargs’: {‘is_seg’: False, ‘order’: 1, ‘order_z’: 0, ‘force_separate_z’: None}, ‘batch_dice’: True}
3. Preprocess
nnunetv2.utilities.plans_handling.plans_handler.ConfigurationManager下preprocessor函数实现了对上述配置的批量化预处理包括转置、裁剪、重采样
裁剪:
图像裁剪就是将三维的医学图像裁剪到它的非零区域,具体方法就是在图像中寻找一个最小的三维bounding box,该bounding box区域以外的值为0,使用这个bounding box对图像进行裁剪。

重采样:
使用数据集各个图像不同spacing的中值,根据target_spacing确定每张图像的目标尺寸。每张图像, spacing和shape之间的乘积为一个定值,代表整个图像在实际空间中的大小。调用skimage库中的reisze函数对每张图像进行resize即可
4. Dataloader
dataloader分为2d和3d两个py文件
2D和3D的DataLoader都继承自batchgenerators.dataloading.data_loader中Dataloader
nnUNetDataLoader2D首先获取finger中指定的slice切片数,生成一个训练数据批次。它选择从 2D 图像中获取特定切片以形成批次。
nnUNetDataLoader3D,生成一个训练数据批次。它选择从 3D 图像中获取特定切片以形成批次
注意!!!
常规的一个epoch的定义就是把dataloader里所有的batch跑完一遍,一个batch中包含n个样本,通常认知里每个样本分别来自一幅图像,那么也就是说一个epoch就是把数据集里的所有图像跑完一遍,但nnunet不同,它的样本可能来自同一幅图像,也就是说同样的batch数,nnunet不一定能完整的跑完整个数据集里的每一幅图像,对于“完整”这个概念,作者在github上是这么解释的:由于在3D图像中一定会切割patch,那么首先无论如何都做不到把原来图像中的每个patch都裁出来进行训练,换句话说首先对于单个图像就没有“完整”训练的概念,那么强调在一个epoch中“完整”遍历每一张图像也就没有意义,所以nnunet的一个epoch不是跑完整个数据集结束,而是跑完指定的batch数结束。
5. data_argument
nnUnet通过batchgenerators库中的数据增强策略实现在线数据增强,如果需要修改数据增强的策略可以在nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer.get_training_transforms这个函数中对数据增强的方法增添和删减。
作者在原始代码中使用的强度数据增强策略为:高斯噪声、模糊、亮度变换、反转图像、对比度增强、随机降采样、连续的两次gamma变换等等,另外还包含空间变换的方法镜像变换、随机旋转、缩放、平移,维度变换的方法针对二进制操作符的随机应用、随机移除一部分与分割标签连接、标签变换的方法对数据和标签的处理、使用掩码进行数据归一化
若想要修改增强策略,具体可根据batchgenerators库数据增强方法,将其添加到nnUnet的get_training_transforms中列表中
training
6. initialize
nnU-Net的初始化用于构建网络结构、配置优化器策略及参数、设置损失函数上述构建的方式均通过先前保存的json文件中获取配置情况
nnU-Net的基础配置的网络结构共有4种,2d,3d,3d_low,3d_cas,其中级联网络需要先在3d_low上训练。
优化器使用SGD,默认学习率衰减策略为线性衰减
损失函数,如果是基于区域的分割则使用Dice和BCE损失结合的方式,否则使用Dice和CE损失,其中Dice使用MemoryEfficientSoftDiceLoss
7-8. train one epoch and validation
同常规神经网络,nnU-Net会把主要代码包装在函数中,通过self调用的方式覆盖
7305





