nnunet扩展:nnunet_extend

部署运行你感兴趣的模型镜像

推荐环境
————————
Ubuntu22.04
CUDA 12.4
torch 2.6.0
nnunetv2 2.6.0
mamba-ssm 2.2.4 (pip install mamba-ssm, 编译时间长)
mmcv 2.1.0
mmsegmentation 1.2.2
mmengine 0.10.6
causal-conv1d 1.5.0.post8
————————
Windows下尚未测试合适的环境,部分Trainer无法使用
torch2.6.0目前无法安装mmcv和mmseg,对应的部分模型无法使用
mamba-ssm和causal-conv1d如未安装,对应的部分模型无法使用

环境搭建

conda create -n nnunet_extend python=3.10 -y
conda activate nnunet_extend

# 安装 torch (注意安装与驱动对应的cuda版本)
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu12.4

# 安装nnunet_extend
git clone https://gitee.com/Eason596/nnunet_extend.git
cd nnunet_extend
pip install -e .

使用

训练

# 参数与nnUNetv2_train保持一致
# cfg_type
# 2d
# 3d_lowres
# 3d_fullres
nnUNetv2_extend_train num_of_dataset cfg_type num_of_fold -tr nnUNetTrainer -p nnUNetPlans

推理

# 新增-pr参数, 其他参数与nnUNetv2_predict保持一致
# step_size default 0.5. < 1 
# disable_tta 关闭测试时数据增强
# continue_prediction 继续未完成的推理
# pr 指定predictor 默认为nnUNetPredictor 新增nnUNetPredictorBinary输出二值图时*255
nnUNetv2_extend_predict -i input_folder -o output_folder -d num_of_dataset -c model_type -f num_of_fold -step_size 0.5 --disable_tta --verbose --continue_prediction -tr nnUNetTrainer -p nnUNetPlans -pr nnUNetPredictorBinary

评估

coming soon…

自定义模型

custom类

  1. 以SwinUNETR为例,在nnunet_extend/network_architecture/monai_model.py中导入模型(也可以是其他自定义架构)

    # nnunet_extend/network_architecture/monai_model.py
    from monai.networks.nets import UNETR, SwinUNETR, UNet, VNet, SegResNet, AttentionUnet
    
  2. 添加nnunet_extend/training/trainers/SwinUNETRTrainer.py文件

    from nnunet_extend.network_architecture.monai_model import SwinUNETR
    from torch import nn
    from typing import List, Tuple, Union
    from nnunet_extend.registry.registry import TRAINERS
    from nnunet_extend.training.trainers.loss import nnUNetTrainerNoDeepSupervisionNoDecoder, nnUNetTrainerCustomDeepSupervision
    from nnunet_extend.training.trainers.base import lr1e4Trainer
    
    @TRAINERS.register_module()
    class SwinUNETR3D128Trainer(nnUNetTrainerNoDeepSupervisionNoDecoder, lr1e4Trainer):
        @staticmethod
        def build_network_architecture(architecture_class_name: str,
                                       arch_init_kwargs: dict,
                                       arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                       num_input_channels: int,
                                       num_output_channels: int,
                                       enable_deep_supervision: bool = True) -> nn.Module:
    
    
            return SwinUNETR(
                img_size = 128,
                in_channels = num_input_channels,
                out_channels = num_output_channels,
                depths = (2, 2, 2, 2),
                num_heads = (3, 6, 12, 24),
                feature_size = 24,
                norm_name = "instance",
                drop_rate = 0.0,
                attn_drop_rate = 0.0,
                dropout_path_rate = 0.0,
                normalize = True,
                use_checkpoint = False,
                spatial_dims = 3,
                downsample = "merging",
                use_v2 = False
            )
    

    [!Note]

    如果模型只有一个输出,继承nnUNetTrainerNoDeepSupervisionNoDecoder

    如果模型有多个输出,可以用到DeepSupervision功能,继承nnUNetTrainerCustomDeepSupervision

    如果需要调整学习率,也可以自定学习率的Trainer并继承

  3. 导入Trainer

    # 在nnunet_extend/training/trainers/__init__.py中导入Trainer
    # ...
    from .SwinUNETRTrainer import SwinUNETR3D128Trainer
    # ...
    

nnunet类

参考umamba

mmseg类

  1. 参考CCNet,复制mmseg中对应模型的配置https://gitee.com/open-mmlab/mmsegmentation/tree/main/configs/base/models/ccnet_r50-d8.py

    CCNet
        ├── configs 
            ├── ccnet_r101-d8.py
            ├── ccnet_r50-d8.py
        ├── model.py
    
    # ccnet_r50-d8.py
    # ...
    # 注释掉data_preprocessor
    model = dict(
        type='EncoderDecoder',
        # data_preprocessor=data_preprocessor,
        # ...
    )
    

    [!NOte]

    使用mmseg自定义backbone时要求backbone需要有in_channels和num_classes参数

  2. 在model.py的代码测试模型是否可以运行

    from mmengine.config import Config
    from nnunet_extend.network_architecture.mmseg import MMSegModel
    import nnunet_extend
    import os
    import torch
    
    if __name__ == '__main__':
        device = 'cuda' if torch.cuda.is_available() else 'cpu' 
        cfg = Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r50-d8.py'))
        # cfg = Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r101-d8.py'))
        model = MMSegModel(3, 1, cfg, 0, False).to(device)
        x = torch.randn(4, 3, 512, 512).to(device)
        print(model(x).shape)
    
  3. 添加nnunet_extend/training/trainers/CCNetTrainer.py文件

    from nnunet_extend.network_architecture.mmseg import MMSegModel
    from nnunet_extend.training.trainers.base import lr6e5Trainer
    from nnunet_extend.training.trainers.loss import nnUNetTrainerNoDeepSupervisionNoDecoder, nnUNetTrainerCustomDeepSupervision
    from torch import nn
    from typing import List, Tuple, Union
    from nnunet_extend.registry.registry import TRAINERS
    from mmengine.config import Config
    import nnunet_extend
    import os
    
    
    @TRAINERS.register_module()
    class CCNet_Res50_d8Trainer(nnUNetTrainerCustomDeepSupervision):
        @staticmethod
        def build_network_architecture(architecture_class_name: str,
                                       arch_init_kwargs: dict,
                                       arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                       num_input_channels: int,
                                       num_output_channels: int,
                                       enable_deep_supervision: bool = True) -> nn.Module:
    
    
            return MMSegModel(
                input_channels=num_input_channels,
                num_classes=num_output_channels,
                cfg=Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r50-d8.py')),
                final_up=8,
                deep_supervision=enable_deep_supervision
            )
    
    @TRAINERS.register_module()
    class CCNet_Res101_d8Trainer(nnUNetTrainerCustomDeepSupervision):
        @staticmethod
        def build_network_architecture(architecture_class_name: str,
                                       arch_init_kwargs: dict,
                                       arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                       num_input_channels: int,
                                       num_output_channels: int,
                                       enable_deep_supervision: bool = True) -> nn.Module:
    
    
            return MMSegModel(
                input_channels=num_input_channels,
                num_classes=num_output_channels,
                cfg=Config.fromfile(os.path.join(nnunet_extend.__path__[0], 'network_architecture', 'CCNet', 'configs', 'ccnet_r50-d8.py')),
                final_up=8,
                deep_supervision=enable_deep_supervision
            )
    
  4. 导入Trainer

    # 在nnunet_extend/training/trainers/__init__.py中导入Trainer
    # ...
    from .CCNetTrainer import CCNet_Res50_d8Trainer, CCNet_Res101_d8Trainer
    # ...
    

自定义损失

  1. 以clDiceLoss为例,添加nnunet_extend/training/loss/clDiceLoss.py文件, 实现相应的损失

  2. 添加对应的Trainer,nnunet_extend/training/trainers/loss/clDiceTrainer.py

  3. 导入Trainer

    # 在nnunet_extend/training/trainers/loss/__init__.py中导入Trainer
    # ...
    from .clDiceTrainer import SoftclDiceTrainer
    # ...
    

组合模型与损失

# 案例,在nnunet_extend/training/trainers/SwinUNETRTrainer.py中添加SwinUNETR3D128_CompoundclDiceTrainer
from nnunet_extend.network_architecture.monai_model import SwinUNETR
from torch import nn
from typing import List, Tuple, Union
from nnunet_extend.registry.registry import TRAINERS
from nnunet_extend.training.trainers.loss import nnUNetTrainerNoDeepSupervisionNoDecoder, CompoundclDiceTrainer
from nnunet_extend.training.trainers.base import lr1e4Trainer

@TRAINERS.register_module()
class SwinUNETR3D128Trainer(nnUNetTrainerNoDeepSupervisionNoDecoder, lr1e4Trainer):
    ...

# 组合不同的模型,损失函数、学习率等
@TRAINERS.register_module()
class SwinUNETR3D128_CompoundclDiceTrainer(CompoundclDiceTrainer, SwinUNETR3D128Trainer):
	pass
    

您可能感兴趣的与本文相关的镜像

Seed-Coder-8B-Base

Seed-Coder-8B-Base

文本生成
Seed-Coder

Seed-Coder是一个功能强大、透明、参数高效的 8B 级开源代码模型系列,包括基础变体、指导变体和推理变体,由字节团队开源

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

热度__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值