推荐环境
————————
Ubuntu22.04
CUDA 12.4
torch 2.6.0
nnunetv2 2.6.0
mamba-ssm 2.2.4 (pip install mamba-ssm, 编译时间长)
mmcv 2.1.0
mmsegmentation 1.2.2
mmengine 0.10.6
causal-conv1d 1.5.0.post8
————————
Windows下尚未测试出合适的环境,部分Trainer无法使用
torch 2.6.0目前无法安装mmcv和mmsegmentation,依赖它们的mmseg类模型无法使用
如未安装mamba-ssm和causal-conv1d,依赖它们的部分模型无法使用
环境搭建
conda create -n nnunet_extend python=3.10 -y
conda activate nnunet_extend
# Install torch. Pick the wheel index that matches the CUDA version supported by your driver.
# The index suffix has no dot (cu124 = CUDA 12.4). Versions are pinned to the recommended torch 2.6.0 environment.
pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# Install nnunet_extend in editable mode
git clone https://gitee.com/Eason596/nnunet_extend.git
cd nnunet_extend
pip install -e .
使用
训练
# Arguments are the same as nnUNetv2_train.
# cfg_type is one of: 2d | 3d_lowres | 3d_fullres
nnUNetv2_extend_train num_of_dataset cfg_type num_of_fold \
    -tr nnUNetTrainer \
    -p nnUNetPlans
推理
# Same arguments as nnUNetv2_predict, plus the new -pr option.
#   -c                     configuration, same values as cfg_type for training: 2d | 3d_lowres | 3d_fullres
#   -step_size             sliding-window step size; default 0.5, keep it below 1
#   --disable_tta          disable test-time augmentation
#   --continue_prediction  resume an unfinished prediction run
#   -pr                    predictor class; default nnUNetPredictor.
#                          nnUNetPredictorBinary (new) multiplies binary masks by 255 when writing them
nnUNetv2_extend_predict -i input_folder -o output_folder -d num_of_dataset -c cfg_type -f num_of_fold \
    -step_size 0.5 --disable_tta --verbose --continue_prediction \
    -tr nnUNetTrainer -p nnUNetPlans -pr nnUNetPredictorBinary
评估
coming soon…
自定义模型
custom类
-
以SwinUNETR为例,在nnunet_extend/network_architecture/monai_model.py中导入模型(也可以是其他自定义架构)
# nnunet_extend/network_architecture/monai_model.py from monai.networks.nets import UNETR, SwinUNETR, UNet, VNet, SegResNet, AttentionUnet -
添加nnunet_extend/training/trainers/SwinUNETRTrainer.py文件
from nnunet_extend.network_architecture.monai_model import SwinUNETR from torch import nn from typing import List, Tuple, Union from nnunet_extend.registry.registry import TRAINERS from nnunet_extend.training.trainers.loss import nnUNetTrainerNoDeepSupervisionNoDecoder, nnUNetTrainerCustomDeepSupervision from nnunet_extend.training.trainers.base import lr1e4Trainer @TRAINERS.register_module() class SwinUNETR3D128Trainer(nnUNetTrainerNoDeepSupervisionNoDecoder, lr1e4Trainer): @staticmethod def build_network_architecture(architecture_class_name: str, arch_init_kwargs: dict, arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], num_input_channels: int, num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: return SwinUNETR( img_size = 128, in_channels = num_input_channels, out_channels = num_output_channels, depths = (2, 2, 2, 2), num_heads = (3, 6, 12, 24), feature_size = 24, norm_name = "instance", drop_rate = 0.0, attn_drop_rate = 0.0, dropout_path_rate = 0.0, normalize = True, use_checkpoint = False, spatial_dims = 3, downsample = "merging", use_v2 = False )[!Note]
如果模型只有一个输出,继承nnUNetTrainerNoDeepSupervisionNoDecoder
如果模型有多个输出,可以使用深度监督(Deep Supervision)功能,继承nnUNetTrainerCustomDeepSupervision
如果需要调整学习率,可以自定义学习率Trainer并同时继承(如上例中的lr1e4Trainer)
-
导入Trainer
# Import the trainer in nnunet_extend/training/trainers/__init__.py
# ...
from .SwinUNETRTrainer import SwinUNETR3D128Trainer
# ...
nnunet类
参考umamba
mmseg类
-
参考CCNet,复制mmseg中对应模型的配置文件,例如: https://gitee.com/open-mmlab/mmsegmentation/tree/main/configs/_base_/models/ccnet_r50-d8.py
CCNet ├── configs ├── ccnet_r101-d8.py ├── ccnet_r50-d8.py ├── model.py# ccnet_r50-d8.py # ... # 注释掉data_preprocessor model = dict( type='EncoderDecoder', # data_preprocessor=data_preprocessor, # ... )[!NOte]
使用mmseg自定义backbone时,backbone需要具有in_channels和num_classes参数
-
在model.py中编写以下代码,测试模型能否正常运行
from mmengine.config import Config from nnunet_extend.network_architecture.mmseg import MMSegModel import nnunet_extend import os import torch if __name__ == '__main__': device = 'cuda' if torch.cuda.is_available() else 'cpu' cfg = Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r50-d8.py')) # cfg = Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r101-d8.py')) model = MMSegModel(3, 1, cfg, 0, False).to(device) x = torch.randn(4, 3, 512, 512).to(device) print(model(x).shape) -
添加nnunet_extend/training/trainers/CCNetTrainer.py文件
from nnunet_extend.network_architecture.mmseg import MMSegModel from nnunet_extend.training.trainers.base import lr6e5Trainer from nnunet_extend.training.trainers.loss import nnUNetTrainerNoDeepSupervisionNoDecoder, nnUNetTrainerCustomDeepSupervision from torch import nn from typing import List, Tuple, Union from nnunet_extend.registry.registry import TRAINERS from mmengine.config import Config import nnunet_extend import os @TRAINERS.register_module() class CCNet_Res50_d8Trainer(nnUNetTrainerCustomDeepSupervision): @staticmethod def build_network_architecture(architecture_class_name: str, arch_init_kwargs: dict, arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], num_input_channels: int, num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: return MMSegModel( input_channels=num_input_channels, num_classes=num_output_channels, cfg=Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r50-d8.py')), final_up=8, deep_supervision=enable_deep_supervision ) @TRAINERS.register_module() class CCNet_Res101_d8Trainer(nnUNetTrainerCustomDeepSupervision): @staticmethod def build_network_architecture(architecture_class_name: str, arch_init_kwargs: dict, arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]], num_input_channels: int, num_output_channels: int, enable_deep_supervision: bool = True) -> nn.Module: return MMSegModel( input_channels=num_input_channels, num_classes=num_output_channels, cfg=Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r50-d8.py')), final_up=8, deep_supervision=enable_deep_supervision ) -
导入Trainer
# Import the trainers in nnunet_extend/training/trainers/__init__.py
# ...
from .CCNetTrainer import CCNet_Res50_d8Trainer, CCNet_Res101_d8Trainer
# ...
自定义损失
-
以clDiceLoss为例,添加nnunet_extend/training/loss/clDiceLoss.py文件, 实现相应的损失
-
添加对应的Trainer:nnunet_extend/training/trainers/loss/clDiceTrainer.py
-
导入Trainer
# Import the trainer in nnunet_extend/training/trainers/loss/__init__.py
# ...
from .clDiceTrainer import SoftclDiceTrainer
# ...
组合模型与损失
# Example: add SwinUNETR3D128_CompoundclDiceTrainer to nnunet_extend/training/trainers/SwinUNETRTrainer.py
from nnunet_extend.network_architecture.monai_model import SwinUNETR
from torch import nn
from typing import List, Tuple, Union
from nnunet_extend.registry.registry import TRAINERS
from nnunet_extend.training.trainers.loss import nnUNetTrainerNoDeepSupervisionNoDecoder, CompoundclDiceTrainer
from nnunet_extend.training.trainers.base import lr1e4Trainer


@TRAINERS.register_module()
class SwinUNETR3D128Trainer(nnUNetTrainerNoDeepSupervisionNoDecoder, lr1e4Trainer):
    ...  # body elided: same as the SwinUNETR3D128Trainer example earlier in this document


# Combine a model, a loss function, a learning rate, etc. via multiple inheritance.
@TRAINERS.register_module()
class SwinUNETR3D128_CompoundclDiceTrainer(CompoundclDiceTrainer, SwinUNETR3D128Trainer):
    """SwinUNETR3D128Trainer combined with CompoundclDiceTrainer.

    CompoundclDiceTrainer is listed first, so its overrides take priority in
    the method resolution order; the network and learning rate come from
    SwinUNETR3D128Trainer.
    """
    pass
4106

被折叠的 条评论
为什么被折叠?



