[Hands-on Experience with nnU-Net-based Deep Learning Medical Image Segmentation Networks] Part 1: ResEnc nnU-Net

Preface

nnU-Net has become the baseline for a great many medical image segmentation networks. Guided by nnU-Net's open-source spirit, a large number of mature medical image segmentation networks have also been open-sourced, which is a huge benefit for researchers in this field. Since there is not yet a mature record of the pitfalls of using nnU-Net available to readers here, this blog series was born. It documents my experience using nnU-Net and the various open-source medical image segmentation networks built on the nnU-Net framework.

Preparation

On a computer with an NVIDIA GPU capable of deep learning, install Ubuntu, the NVIDIA driver, CUDA, cuDNN, Anaconda, and VS Code to complete the basic environment setup. (Please search for instructions on these steps yourself.)

You can also consider renting a cloud server; please search for the details of training on cloud servers yourself. In that case there is no need to create a virtual environment: once a machine with CUDA already configured is running, go straight to step 3 of the preparation and install nnU-Net v2.

This post covers the use of nnU-Net v2; the use of the various open-source medical image segmentation networks built on the nnU-Net framework will be covered in later posts.

  1. Create a virtual environment
conda create -n new_env python=3.10
  2. Activate the virtual environment
conda activate new_env
  3. Install nnU-Net v2
pip install nnunetv2

When installing nnunetv2, PyTorch and the other required packages are installed automatically; just wait patiently for it to finish.
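
An optional sanity check before moving on (a minimal sketch; it only assumes the packages installed above) verifies that PyTorch sees the GPU and that nnunetv2 is importable:

import torch
import nnunetv2

print(nnunetv2.__file__)          # where nnunetv2 is installed; useful later for editing paths.py
print(torch.cuda.is_available())  # should print True if the driver/CUDA setup is working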

After installation, create three folders:

nnUNet_raw
nnUNet_preprocessed
nnUNet_results

nnUNet_raw holds the raw medical imaging data, nnUNet_preprocessed holds the preprocessed data together with the automatically generated plans (network configuration and training parameters), and nnUNet_results holds the training and test results.
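
The folders can live anywhere you have write access. A minimal sketch (the base path is a placeholder):

import os

base = '/home/your_user'   # placeholder: pick any location with enough disk space
for name in ('nnUNet_raw', 'nnUNet_preprocessed', 'nnUNet_results'):
    os.makedirs(os.path.join(base, name), exist_ok=True)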

Locate the installation directory of nnunetv2, typically

/home/***/miniconda3/envs/new_env/lib/python3.10/site-packages/nnunetv2

If it was installed directly as root (without a virtual environment), locate it via which pip; the path may then look like

/root/miniconda3/lib/python3.10/site-packages/nnunetv2

where *** is your Ubuntu user name.
Likewise, every occurrence of new_env above can be replaced with the name of whatever Anaconda environment you use.

In that folder, find and open the file paths.py, and change the right-hand sides of

nnUNet_raw=
nnUNet_preprocessed=
nnUNet_results=

to the paths of the three folders you just created.
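
For example (a minimal sketch; the folder locations are placeholders for the three directories you created), the edited lines in paths.py might read:

nnUNet_raw = '/home/your_user/nnUNet_raw'
nnUNet_preprocessed = '/home/your_user/nnUNet_preprocessed'
nnUNet_results = '/home/your_user/nnUNet_results'

Alternatively, the unmodified paths.py reads these three values from environment variables of the same names, so exporting nnUNet_raw, nnUNet_preprocessed and nnUNet_results in your shell achieves the same effect without editing the file.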

With that, the preliminary setup is complete.

Preparing the dataset

The various open-source medical image segmentation networks built on the nnU-Net framework can, for the segmentation dataset you provide, fully automatically design a deep-supervised, patch-based network, train it, and evaluate it on the validation set.

If you only want to run someone else's code on your own dataset, without modifying the network architecture or tuning the training parameters, you can simply treat the validation set as the test set. A single run then covers both training and testing.

The first task when using any nnU-Net-based open-source segmentation network is therefore to build a dataset in the required format.

Below, the official nnU-Net v2 script for creating Dataset042_BraTS18 is used to explain how to build a dataset for nnU-Net-based segmentation networks.
The code is available at:
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/dataset_conversion/Dataset042_BraTS18.py

The code is as follows:

import multiprocessing
import shutil
import SimpleITK as sitk
import numpy as np
from tqdm import tqdm
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw


def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
    # use this for segmentation only!!!
    # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
    img = sitk.ReadImage(in_file)
    img_npy = sitk.GetArrayFromImage(img)

    uniques = np.unique(img_npy)
    for u in uniques:
        if u not in [0, 1, 2, 4]:
            raise RuntimeError('unexpected label')

    seg_new = np.zeros_like(img_npy)
    seg_new[img_npy == 4] = 3
    seg_new[img_npy == 2] = 1
    seg_new[img_npy == 1] = 2
    img_corr = sitk.GetImageFromArray(seg_new)
    img_corr.CopyInformation(img)
    sitk.WriteImage(img_corr, out_file)


def convert_labels_back_to_BraTS(seg: np.ndarray):
    new_seg = np.zeros_like(seg)
    new_seg[seg == 1] = 2
    new_seg[seg == 3] = 4
    new_seg[seg == 2] = 1
    return new_seg


def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
    a = sitk.ReadImage(join(input_folder, filename))
    b = sitk.GetArrayFromImage(a)
    c = convert_labels_back_to_BraTS(b)
    d = sitk.GetImageFromArray(c)
    d.CopyInformation(a)
    sitk.WriteImage(d, join(output_folder, filename))


def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str,
                                                                num_processes: int = 12):
    """
    reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the
    """
    maybe_mkdir_p(output_folder)
    nii = subfiles(input_folder, suffix='.nii.gz', join=False)
    with multiprocessing.get_context("spawn").Pool(num_processes) as p:
        p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))


if __name__ == '__main__':
    brats_data_dir = ...

    task_id = 42
    task_name = "BraTS2018"

    foldername = "Dataset%03.0d_%s" % (task_id, task_name)

    # setting up nnU-Net folders
    out_base = join(nnUNet_raw, foldername)
    imagestr = join(out_base, "imagesTr")
    labelstr = join(out_base, "labelsTr")
    maybe_mkdir_p(imagestr)
    maybe_mkdir_p(labelstr)

    case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='Brats', join=False)
    case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="Brats", join=False)

    print("copying hggs")
    for c in tqdm(case_ids_hgg):
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"),
                                                             join(labelstr, c + '.nii'))
    print("copying lggs")
    for c in tqdm(case_ids_lgg):
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"),
                                                             join(labelstr, c + '.nii'))

    generate_dataset_json(out_base,
                          channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
                          labels={
                              'background': 0,
                              'whole tumor': (1, 2, 3),
                              'tumor core': (2, 3),
                              'enhancing tumor': (3,)
                          },
                          num_training_cases=(len(case_ids_lgg) + len(case_ids_hgg)),
                          file_ending='.nii',
                          regions_class_order=(1, 2, 3),
                          license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
                          reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
                          dataset_release='1.0')

Only the main function is explained below:

Module 1

    brats_data_dir = ...

    task_id = 42
    task_name = "BraTS2018"

    foldername = "Dataset%03.0d_%s" % (task_id, task_name)

    # setting up nnU-Net folders
    out_base = join(nnUNet_raw, foldername)
    imagestr = join(out_base, "imagesTr")
    labelstr = join(out_base, "labelsTr")
    maybe_mkdir_p(imagestr)
    maybe_mkdir_p(labelstr)

This part sets up the basic dataset information and needs little further explanation. For your own dataset, task_id can be any unused number from 100 to 999, and task_name can be any name you like.
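
For a custom dataset, the same block only needs different constants. A sketch reusing the imports of the script above; the source directory, ID and name are purely hypothetical placeholders:

    my_data_dir = '/data/MyLiverCT_raw'          # placeholder: wherever your raw data lives

    task_id = 101                                # any unused ID between 100 and 999
    task_name = "MyLiverCT"                      # any name you like

    foldername = "Dataset%03.0d_%s" % (task_id, task_name)   # -> Dataset101_MyLiverCT

    out_base = join(nnUNet_raw, foldername)
    imagestr = join(out_base, "imagesTr")
    labelstr = join(out_base, "labelsTr")
    maybe_mkdir_p(imagestr)
    maybe_mkdir_p(labelstr)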

Module 2


    case_ids_hgg = subdirs(join(brats_data_dir, "HGG"), prefix='Brats', join=False)
    case_ids_lgg = subdirs(join(brats_data_dir, "LGG"), prefix="Brats", join=False)

    print("copying hggs")
    for c in tqdm(case_ids_hgg):
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
        shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii"),
                                                             join(labelstr, c + '.nii'))
    print("copying lggs")
    for c in tqdm(case_ids_lgg):
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii"), join(imagestr, c + '_0000.nii'))
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii"), join(imagestr, c + '_0001.nii'))
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii"), join(imagestr, c + '_0002.nii'))
        shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii"), join(imagestr, c + '_0003.nii'))

        copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii"),
                                                             join(labelstr, c + '.nii'))

This part copies the data into place, and it serves two purposes:

  1. Move the image data into the folders created above; each input channel (modality) receives the suffix _0000, _0001, ..., while the corresponding label file carries no suffix.
  2. Renumber the labels of the manual segmentations so that 0 is the background and the foreground labels are consecutive integers starting from 1 (1, 2, 3, ...), as in the sketch below.
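
If your own manual segmentations use other label values (say 0, 3, 5), a small helper in the same spirit as copy_BraTS_segmentation_and_convert_labels_to_nnUNet can remap them. This is only a sketch; the file names and the mapping {0: 0, 3: 1, 5: 2} are illustrative:

import SimpleITK as sitk
import numpy as np

def remap_labels_for_nnunet(in_file: str, out_file: str, mapping: dict) -> None:
    # read the label image, remap every value according to `mapping`, and keep the geometry information
    img = sitk.ReadImage(in_file)
    arr = sitk.GetArrayFromImage(img)
    new_arr = np.zeros_like(arr)
    for old_value, new_value in mapping.items():
        new_arr[arr == old_value] = new_value
    out = sitk.GetImageFromArray(new_arr)
    out.CopyInformation(img)
    sitk.WriteImage(out, out_file)

# example: original labels 0 (background), 3, 5  ->  0, 1, 2
# remap_labels_for_nnunet('case01_seg.nii.gz', 'labelsTr/case01.nii.gz', {0: 0, 3: 1, 5: 2})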

Module 3

    generate_dataset_json(out_base,
                          channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
                          labels={
                              'background': 0,
                              'whole tumor': (1, 2, 3),
                              'tumor core': (2, 3),
                              'enhancing tumor': (3,)
                          },
                          num_training_cases=(len(case_ids_lgg) + len(case_ids_hgg)),
                          file_ending='.nii',
                          regions_class_order=(1, 2, 3),
                          license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
                          reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
                          dataset_release='1.0')

This part is the key step: it generates the dataset.json file that nnU-Net reads. Here, channel_names lists the input channels (imaging modalities) and their names, and labels defines the segmentation targets the network outputs. regions_class_order specifies the order in which overlapping regions are written into the final segmentation map. If your foreground regions do not overlap, define labels as plain integers (background = 0, foreground classes 1, 2, ...); regions_class_order is then not needed.
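
For a dataset whose foreground labels do not overlap, the call is much simpler. A minimal sketch for a hypothetical single-modality CT dataset with two foreground classes; out_base, case_ids and the class names are placeholders:

generate_dataset_json(out_base,
                      channel_names={0: 'CT'},
                      labels={
                          'background': 0,
                          'organ': 1,       # hypothetical foreground class 1
                          'lesion': 2       # hypothetical foreground class 2
                      },
                      num_training_cases=len(case_ids),
                      file_ending='.nii.gz')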

Network training

The following assumes that you prefer running Python code from VS Code rather than launching nnU-Net from the command line.

I suggest copying and adapting the following code from the nnunetv2 source files, so that planning, preprocessing and training can all be started from a single Python script.

import os
import socket
import shutil
import importlib
import pkgutil
from typing import Union, Optional, List, Type, Tuple

import numpy as np
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.backends import cudnn
from torch.serialization import add_safe_globals

import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 0  # Disable caching

from batchgenerators.utilities.file_and_folder_operations import *

import nnunetv2
from nnunetv2.configuration import default_num_processes
from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor
from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_id_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class

# allow numpy scalar types when torch.load runs with weights_only=True (default since PyTorch 2.6)
add_safe_globals([np.generic])
add_safe_globals([np.core.multiarray.scalar])

os.environ["TORCH_COMPILE"] = "0"

def plan_and_preprocess_entry(dataset_id, conf):
    # mimics the argparse namespace of the nnUNetv2_plan_and_preprocess command-line entry point
    class args:
        d = dataset_id
        fpe = 'DatasetFingerprintExtractor'
        npfp = 4
        verify_dataset_integrity = True
        no_pp = False
        clean = True
        pl = 'nnUNetPlannerResEncM'                 # experiment planner: ResEnc preset
        gpu_memory_target = None
        preprocessor_name = 'DefaultPreprocessor'
        overwrite_target_spacing = None
        overwrite_plans_name = 'nnUNetResEncUNetL'  # plans identifier; must match the one passed to run_training_entry
        c = conf
        np = None
        verbose = False

    # fingerprint extraction
    print("Fingerprint extraction...")
    extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose)

    # experiment planning
    print('Experiment planning...')
    plans_identifier = plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name,
                                        args.overwrite_target_spacing, args.overwrite_plans_name)

    # manage default np
    if args.np is None:
        default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8}
        np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]
    else:
        np = args.np
    # preprocessing
    if not args.no_pp:
        print('Preprocessing...')
        preprocess(args.d, plans_identifier, args.c, np, args.verbose)

def find_free_network_port() -> int:
    """Finds a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    port = s.getsockname()[1]
    s.close()
    return port


def get_trainer_from_args(dataset_name_or_id: Union[int, str],
                          configuration: str,
                          fold: int,
                          trainer_name: str = 'nnUNetTrainer',
                          plans_identifier: str = 'nnUNetPlans',
                          use_compressed: bool = False,
                          device: torch.device = torch.device('cuda')):
    # load nnunet class and do sanity checks
    nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_name, 'nnunetv2.training.nnUNetTrainer')
    if nnunet_trainer is None:
        raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
                           f'nnunetv2.training.nnUNetTrainer ('
                           f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
                           f'else, please move it there.')
    assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
                                                    'nnUNetTrainer'

    # handle dataset input. If it's an ID we need to convert to int from string
    if dataset_name_or_id.startswith('Dataset'):
        pass
    else:
        try:
            dataset_name_or_id = int(dataset_name_or_id)
        except ValueError:
            raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
                             f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
                             f'input: {dataset_name_or_id}')

    # initialize nnunet trainer
    preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
    plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
    plans = load_json(plans_file)
    dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
    nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
                                    dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device)
    return nnunet_trainer


def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
                          pretrained_weights_file: str = None):
    if continue_training and pretrained_weights_file is not None:
        raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
                           'be used at the beginning of the training.')
    if continue_training:
        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
        if not isfile(expected_checkpoint_file):
            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
        # special case where --c is used to run a previously aborted validation
        if not isfile(expected_checkpoint_file):
            expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
        if not isfile(expected_checkpoint_file):
            print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
                               f"continue from. Starting a new training...")
            expected_checkpoint_file = None
    elif validation_only:
        expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
        if not isfile(expected_checkpoint_file):
            raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
    else:
        if pretrained_weights_file is not None:
            if not nnunet_trainer.was_initialized:
                nnunet_trainer.initialize()
            load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)
        expected_checkpoint_file = None

    if expected_checkpoint_file is not None:
        nnunet_trainer.load_checkpoint(expected_checkpoint_file)


def setup_ddp(rank, world_size):
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup_ddp():
    dist.destroy_process_group()


def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val,
            pretrained_weights, npz, val_with_best, world_size):
    setup_ddp(rank, world_size)
    torch.cuda.set_device(torch.device('cuda', dist.get_rank()))

    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,
                                           use_compressed)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.'

    maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)

    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    if not val:
        nnunet_trainer.run_training()

    if val_with_best:
        nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
    nnunet_trainer.perform_actual_validation(npz)
    cleanup_ddp()


def run_training(dataset_name_or_id: Union[str, int],
                 configuration: str, fold: Union[int, str],
                 trainer_class_name: str = 'nnUNetTrainer',
                 plans_identifier: str = 'nnUNetPlans',
                 pretrained_weights: Optional[str] = None,
                 num_gpus: int = 1,
                 use_compressed_data: bool = False,
                 export_validation_probabilities: bool = False,
                 continue_training: bool = False,
                 only_run_validation: bool = False,
                 disable_checkpointing: bool = False,
                 val_with_best: bool = False,
                 device: torch.device = torch.device('cuda')):
    if plans_identifier == 'nnUNetPlans':
        print("\n############################\n"
              "INFO: You are using the old nnU-Net default plans. We have updated our recommendations. "
              "Please consider using those instead! "
              "Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md"
              "\n############################\n")
    if isinstance(fold, str):
        if fold != 'all':
            try:
                fold = int(fold)
            except ValueError as e:
                print(f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!')
                raise e

    if val_with_best:
        assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing'

    if num_gpus > 1:
        assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"

        os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ.keys():
            port = str(find_free_network_port())
            print(f"using port {port}")
            os.environ['MASTER_PORT'] = port  # str(port)

        mp.spawn(run_ddp,
                 args=(
                     dataset_name_or_id,
                     configuration,
                     fold,
                     trainer_class_name,
                     plans_identifier,
                     use_compressed_data,
                     disable_checkpointing,
                     continue_training,
                     only_run_validation,
                     pretrained_weights,
                     export_validation_probabilities,
                     val_with_best,
                     num_gpus),
                 nprocs=num_gpus,
                 join=True)
    else:
        nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                                               plans_identifier, use_compressed_data, device=device)

        if disable_checkpointing:
            nnunet_trainer.disable_checkpointing = disable_checkpointing

        assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'

        maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)

        if torch.cuda.is_available():
            cudnn.deterministic = False
            cudnn.benchmark = True

        if not only_run_validation:
            nnunet_trainer.run_training()

        if val_with_best:
            nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
        nnunet_trainer.perform_actual_validation(export_validation_probabilities)


def run_training_entry(dataset_id, conf, folder, tran='nnUNetTrainer', plan='nnUNetPlans'):
    # mimics the argparse namespace of the nnUNetv2_train command-line entry point
    class args:
        dataset_name_or_id = dataset_id
        configuration = conf
        fold = folder               # fold index as a string ('0'-'4') or 'all'
        tr = tran                   # trainer class name
        p = plan                    # plans identifier
        pretrained_weights = None
        num_gpus = 1
        use_compressed = False
        npz = True                  # export predicted probabilities of the validation cases
        c = True                    # continue from an existing checkpoint if one is found
        val = False
        disable_checkpointing = False
        device = 'cuda'
        val_best = False


    assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
    if args.device == 'cpu':
        # let's allow torch to use hella threads
        import multiprocessing
        torch.set_num_threads(multiprocessing.cpu_count())
        device = torch.device('cpu')
    elif args.device == 'cuda':
        # multithreading in torch doesn't help nnU-Net if run on GPU
        # torch.set_num_threads(1)
        # torch.set_num_interop_threads(1)
        device = torch.device('cuda')
    else:
        device = torch.device('mps')

    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
                 args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best,
                 device=device)

if __name__ == '__main__':
    # plan and preprocess dataset 666, then train fold 0 of its 3d_fullres configuration
    plan_and_preprocess_entry([666], ['3d_fullres'])
    run_training_entry('Dataset666_test', '3d_fullres', '0', 'nnUNetTrainer', 'nnUNetResEncUNetL')

To reiterate: all of the code above comes from the open-source nnU-Net code; it was copied from nnunetv2/run/run_training.py and nnunetv2/experiment_planning/plan_and_preprocess_api.py in the nnU-Net v2 source.

If there is interest, a follow-up post could walk through this code in detail. For now, here is how to use it:

In the main block, change [666] in the first line to the ID of the dataset you created; ['3d_fullres'] can be changed to ['2d'], ['3d_lowres'], or ['3d_cascade_fullres'], following the official nnU-Net documentation.

In the second line, change Dataset666_test to DatasetXXX_YourName, i.e. the ID and name of your own dataset, and change 3d_fullres in the same way as in the first line. The last argument, 'nnUNetResEncUNetL', is the plans identifier and must match the overwrite_plans_name set in plan_and_preprocess_entry.
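
As an illustration only (the dataset ID and name are hypothetical), the modified main block could read:

if __name__ == '__main__':
    plan_and_preprocess_entry([101], ['3d_fullres'])
    run_training_entry('Dataset101_MyLiverCT', '3d_fullres', '0', 'nnUNetTrainer', 'nnUNetResEncUNetL')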

With that, the residual-encoder (ResEnc) nnU-Net is ready, and you can get on with training and testing.

Some suggested modifications to the nnU-Net v2 source files:

Nobody wants training to crash with an error partway through!
So, find the file nnUNetTrainer.py, usually located at

/root/miniconda3/lib/python3.10/site-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

and make the following modification:
1.

checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)

change to

checkpoint = torch.load(filename_or_checkpoint, map_location=self.device, weights_only = False)

Otherwise you will get the following error:

  File "/root/miniconda3/lib/python3.10/site-packages/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 1184, in load_checkpoint
    checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/serialization.py", line 1470, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. 