1. Environment
- Install the base environment yourself (not covered here).
- Create a Python virtual environment:
conda create -n nnUnet python=3.9
- Check the server's driver and CUDA version (start the instance with a GPU):
nvidia-smi
- Install the matching PyTorch build:
  - Official site: PyTorch versions
  - Copy the command for your CUDA version and run it:
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
- Check the installed environment:
  - Method 1: run a program
  - Method 2: check directly
import torchvision
import torch
print(torch.version.cuda)
print(torch.__version__)
print(torchvision.__version__)
print(torch.cuda.is_available())
# Output on the author's machine (the values depend on the installed build):
# 11.3
# 1.12.0
# 0.13.0
# True
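2. Installing nnU-Net (two methods)
- Download nnU-Net, using either of the following two methods:
git clone https://github.com/MIC-DKFZ/nnUNet.git
- Or download nnunet.zip (Baidu Cloud extraction code: f8xf).
- Enter the nnUNet folder: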
cd nnUNet
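- Install the dependencies. This also adds new terminal commands that are used throughout the nnU-Net pipeline; they all share the prefix nnUNetv2_. (Use either of the two methods.)
- Method 1: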
pip install -e .
Installing from a mirror is faster, for example:
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
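- Method 2: run the setup.py file in the folder directly. Note that pip install -e . corresponds to a development (editable) install, i.e. python setup.py develop, whereas the command below performs a regular install: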
python setup.py install
- Install hiddenlayer (optional but recommended). hiddenlayer lets nnU-Net generate network topology diagrams.
- Method 1: run in the terminal:
pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git
- Method 2:
- Download hiddenlayer.zip from https://github.com/FabianIsensee/hiddenlayer.git, or from Baidu Cloud (extraction code: ysoq).
- Unzip it and locate setup.py.
- Install hiddenlayer:
python setup.py install
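3. Structuring the dataset
3.1 Creating the folders
Under the nnUNet-wh folder, create a folder named DATASET.
- Inside it, create the folders dataset_conversion, nnUNet_preprocessed, nnUNet_raw and nnUNet_trained_models (naming the last one nnUNet_results directly is more convenient).
- Inside nnUNet_raw, create the folder Dataset040_KiTS.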
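The folder layout above can also be created from Python. The following is a minimal sketch: the helper name create_nnunet_folders is hypothetical, and the demo writes into a temporary directory that you would replace with your own DATASET path.

```python
import os
import tempfile


def create_nnunet_folders(base_dir, dataset_name="Dataset040_KiTS"):
    """Create the DATASET sub-folders described above and return their paths."""
    subfolders = ["dataset_conversion", "nnUNet_preprocessed",
                  "nnUNet_raw", "nnUNet_trained_models"]
    paths = {}
    for name in subfolders:
        paths[name] = os.path.join(base_dir, name)
        os.makedirs(paths[name], exist_ok=True)  # no error if it already exists
    # The raw dataset folder lives inside nnUNet_raw.
    paths[dataset_name] = os.path.join(paths["nnUNet_raw"], dataset_name)
    os.makedirs(paths[dataset_name], exist_ok=True)
    return paths


# Demo in a temporary directory; replace it with your own DATASET folder.
demo_base = tempfile.mkdtemp()
created = create_nnunet_folders(demo_base)
print(sorted(created))
```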
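3.2 Setting the paths nnU-Net reads from
- In a terminal with the nnunet environment active, run vim ~/.bashrc, press Insert (or i) to enter insert mode, and append the three export lines below to the end of the file. Then press Esc, type :wq and press Enter to save and exit.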
vim ~/.bashrc
# Note: use your own paths here, i.e. the folders created in the previous step.
# This note does not belong in ~/.bashrc; add only the three export lines below.
export nnUNet_raw="/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_raw"
export nnUNet_preprocessed="/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_preprocessed"
export nnUNet_results="/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_trained_models"
Then run source ~/.bashrc so that the paths take effect.
source ~/.bashrc
Important: echo each of the three variables to verify that it is set. If any of them is not recognized, the later preprocessing step cannot run.
echo $nnUNet_results
echo $nnUNet_raw
echo $nnUNet_preprocessed
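The same check can be done from Python, which is handy inside scripts. This is a small sketch; the helper name missing_nnunet_vars is hypothetical, and only the three variable names come from the export lines above.

```python
import os

REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def missing_nnunet_vars(environ=None):
    """Return the names of the required variables that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [name for name in REQUIRED_VARS if not environ.get(name)]


missing = missing_nnunet_vars()
if missing:
    print("Not set:", ", ".join(missing))
else:
    print("All three nnU-Net paths are set.")
```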
3.3 Renaming the dataset
- In the dataset_conversion folder, create a file named Dataset040_KiTS.py.
Code for method 1:
import os
import json
import shutil


def save_json(obj, file, indent=4, sort_keys=True):
    # Write obj to file as pretty-printed JSON.
    with open(file, 'w') as f:
        json.dump(obj, f, sort_keys=sort_keys, indent=indent)


def maybe_mkdir_p(directory):
    # Create directory and any missing parents. An existing folder is not an error,
    # even when two jobs create it at the same time (e.g. on network drives).
    os.makedirs(directory, exist_ok=True)


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    # List the sub-directories of folder, optionally filtered by prefix/suffix.
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res
base = "/root/autodl-tmp/kits"  # path to the original dataset
out = "/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_raw/Dataset040_KiTS"  # structured dataset folder
cases = subdirs(base, join=False)
maybe_mkdir_p(out)
maybe_mkdir_p(os.path.join(out, "imagesTr"))
maybe_mkdir_p(os.path.join(out, "imagesTs"))
maybe_mkdir_p(os.path.join(out, "labelsTr"))
for c in cases:
    case_id = int(c.split("_")[-1])
    if case_id < 210:
        # cases 0-209 have labels and form the training set
        shutil.copy(os.path.join(base, c, "imaging.nii.gz"), os.path.join(out, "imagesTr", c + "_0000.nii.gz"))
        shutil.copy(os.path.join(base, c, "segmentation.nii.gz"), os.path.join(out, "labelsTr", c + ".nii.gz"))
    else:
        # the remaining cases only have images and go to the test set
        shutil.copy(os.path.join(base, c, "imaging.nii.gz"), os.path.join(out, "imagesTs", c + "_0000.nii.gz"))
json_dict = {}
"""
name: name of the dataset
description: description of the dataset
modality: 0 means CT, 1 means MR; nnU-Net preprocesses each modality differently (renamed to channel_names in nnU-Net v2)
labels: the class each label value stands for (the key/value order in v2 is the reverse of v1)
file_ending: new in nnU-Net v2
numTraining: number of training cases
numTest: number of test cases
training: image/label path pairs of the training set
test: images of the test set only (unlike training, no labels)
"""
json_dict['name'] = "KiTS"
json_dict['description'] = "kidney and kidney tumor segmentation"
json_dict['tensorImageSize'] = "4D"
json_dict['reference'] = "KiTS data for nnunet"
json_dict['licence'] = ""
json_dict['release'] = "0.0"
json_dict['channel_names'] = {
    "0": "CT",
}
# nnU-Net v2 expects integer label values, not strings
json_dict['labels'] = {
    "background": 0,
    "Kidney": 1,
    "Tumor": 2
}
# only the labelled cases (IDs below 210) count as training cases
train_cases = [c for c in cases if int(c.split("_")[-1]) < 210]
json_dict['numTraining'] = len(train_cases)  # 210 for KiTS19
json_dict['file_ending'] = ".nii.gz"
json_dict['numTest'] = 0
json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in train_cases]
# json_dict['test'] = []
save_json(json_dict, os.path.join(out, "dataset.json"))
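The files generated by this method are shown in the figure:
Code for method 2:
I created a new file at autodl-fs/nnUNet-wh/DATASET/dataset_conversion/Dataset210_KiTS2019.py by copying autodl-fs/nnUNet-wh/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py from nnU-Net v2, then changed the labels and the number of cases as follows:
The labels become [0, 1, 2], because KiTS19 only has background (0), kidney (1) and tumor (2), with 210 cases available for training,
whereas KiTS23 has an additional label 3 and 489 training cases (the 220 in the script name is the nnU-Net dataset ID, not the number of cases).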
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_raw


def convert_kits2019(kits_base_dir: str, nnunet_dataset_id: int = 209):
    task_name = "KiTS2019"
    foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)  # resulting folder name: Dataset209_KiTS2019
    # setting up nnU-Net folders
    out_base = join(nnUNet_raw, foldername)
    imagestr = join(out_base, "imagesTr")
    labelstr = join(out_base, "labelsTr")
    # imagests = join(out_base, "imagesTs")  # uncomment to also export a test set
    maybe_mkdir_p(imagestr)
    maybe_mkdir_p(labelstr)
    # maybe_mkdir_p(imagests)  # only valid once imagests is defined above
    cases = subdirs(kits_base_dir, prefix='case_', join=False)
    for tr in cases:
        case_id = int(tr.split("_")[-1])
        if case_id < 210:
            shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
            shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
        else:
            pass
            # shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagests, f'{tr}_0000.nii.gz'))
    generate_dataset_json(out_base, {0: "CT"},
                          labels={
                              "background": 0,
                              "kidney": 1,
                              "tumor": 2
                          },
                          # regions_class_order=(1, 3, 2),
                          num_training_cases=210, file_ending='.nii.gz',
                          dataset_name=task_name, reference='none',
                          release='prerelease',
                          description="KiTS2019")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('input_folder', type=str,
                        help="The downloaded and extracted KiTS2019 dataset (must have case_XXXXX subfolders)")
    parser.add_argument('-d', required=False, type=int, default=209, help='nnU-Net Dataset ID, default: 209')
    args = parser.parse_args()
    kits_base = args.input_folder
    convert_kits2019(kits_base, args.d)
    # example input folder: /root/autodl-tmp/kits
Run in the terminal:
python Dataset210_KiTS2019.py /root/autodl-tmp/kits
Summary: the two methods generate different folders (Dataset040_KiTS and Dataset209_KiTS2019), so their dataset.json files differ as well.
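3.4 Data preprocessing (on a GPU instance)
3.4.1 Resampling
For background on resampling, see: [Medical image preprocessing pipeline]
3.4.1.1 Kidney tumor segmentation in practice (KiTS19)
- nnU-Net does its own resampling, but the winning method resampled all cases to a voxel spacing of 3.22 x 1.62 x 1.62 (reference: nnUnet kidney tumor segmentation in practice (KiTS19)).
- On the CPU-only instance the script below died with Killed, the same problem as in the preprocessing step further down; on the GPU instance it ran without problems. Killed usually means the OS ran out of RAM, so the larger memory of the GPU instance is the likely reason.
- To protect the original files, write the resampled data to a separate new folder.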
import numpy as np
import SimpleITK as sitk
import os

'''
Purpose: resample all cases to a voxel spacing of 3.22 x 1.62 x 1.62.
Known issue: on the CPU-only instance the process ended with "Killed".
'''


# resampling helper
def transform(image, newSpacing, resamplemethod=sitk.sitkNearestNeighbor):
    # set up the resampling filter
    resample = sitk.ResampleImageFilter()
    # original size in voxels
    originSize = image.GetSize()
    # original voxel spacing
    originSpacing = image.GetSpacing()
    # keep the physical extent: new size = old size * old spacing / new spacing
    newSize = [
        int(np.round(originSize[0] * originSpacing[0] / newSpacing[0])),
        int(np.round(originSize[1] * originSpacing[1] / newSpacing[1])),
        int(np.round(originSize[2] * originSpacing[2] / newSpacing[2]))
    ]
    print('current size:', newSize)
    # The sampling grid of the output space is specified with the spacing along each dimension and the origin.
    resample.SetOutputSpacing(newSpacing)
    # keep the original origin
    resample.SetOutputOrigin(image.GetOrigin())
    # keep the original direction
    resample.SetOutputDirection(image.GetDirection())
    resample.SetSize(newSize)
    # interpolation method
    resample.SetInterpolator(resamplemethod)
    # identity transform: only the sampling grid changes
    resample.SetTransform(sitk.Euler3DTransform())
    # default pixel value: resample.SetDefaultPixelValue(image.GetPixelIDValue())
    return resample.Execute(image)
# resample the images with B-spline interpolation
data_path = "/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_raw/Dataset209_KiTS2019/imagesTr"
# data_path311 = "/root/autodl-tmp/nnUNet/dataset/nnUNet_raw/Dataset040_KiTS/imagesTr"
data_path311 = "/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_raw/Dataset311_KiTS209/imagesTr"
os.makedirs(data_path311, exist_ok=True)  # make sure the output folder exists
for path in sorted(os.listdir(data_path)):
    print(path)
    img_path = os.path.join(data_path, path)
    img_itk = sitk.ReadImage(img_path)
    print('origin size:', img_itk.GetSize())
    # images use sitk.sitkBSpline interpolation
    new_itk = transform(img_itk, [3.22, 1.62, 1.62], sitk.sitkBSpline)  # or sitk.sitkLinear
    # sitk.WriteImage(new_itk, img_path)  # would overwrite the original
    data_path3 = os.path.join(data_path311, path)
    sitk.WriteImage(new_itk, data_path3)
print('images are resampled!')
print('-' * 20)
# resample the masks with nearest-neighbour interpolation (keeps label values intact)
label_path = "/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_raw/Dataset209_KiTS2019/labelsTr"
label_path311 = "/root/autodl-fs/nnUNet-wh/DATASET/nnUNet_raw/Dataset311_KiTS209/labelsTr"
os.makedirs(label_path311, exist_ok=True)  # make sure the output folder exists
for path in sorted(os.listdir(label_path)):
    print(path)
    img_path = os.path.join(label_path, path)
    img_itk = sitk.ReadImage(img_path)
    print('origin size:', img_itk.GetSize())
    # segmentations use the default sitk.sitkNearestNeighbor interpolation
    new_itk = transform(img_itk, [3.22, 1.62, 1.62])
    # sitk.WriteImage(new_itk, img_path)  # would overwrite the original
    label_path3 = os.path.join(label_path311, path)
    sitk.WriteImage(new_itk, label_path3)
print('labels are resampled!')
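The size computation inside transform is easy to check by hand: the physical extent (size times spacing) stays the same, so each axis gets round(size * spacing / new_spacing) voxels. The sketch below repeats that arithmetic in plain Python; the helper name resampled_size and the input case (100 slices of 5.0 mm, 512 x 512 pixels of 0.81 mm) are made up for illustration.

```python
def resampled_size(size, spacing, new_spacing):
    """Number of voxels per axis after resampling to new_spacing.

    The physical extent (size * spacing) is preserved on every axis.
    """
    return [int(round(n * s / t)) for n, s, t in zip(size, spacing, new_spacing)]


# Hypothetical case, axes in the same order as the target spacing.
new_size = resampled_size([100, 512, 512], [5.0, 0.81, 0.81], [3.22, 1.62, 1.62])
print(new_size)
```

Halving the in-plane resolution (0.81 mm to 1.62 mm) halves 512 to 256 voxels, while the slice axis grows because 3.22 mm is finer than 5.0 mm.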
3.4.1.2 Liver tumor segmentation in practice (LiTS17)
- No resampling was done here, but the dataset check failed with:
Error: Spacing mismatch between segmentation and corresponding images.
Spacing images: (244, 512, 512).
Spacing seg: (244, 512, 512).
Image files: ['/autodl-fs/data/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/imagesTr/liver_00048_0000.nii.gz'].
Seg file: /autodl-fs/data/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/labelsTr/liver_00048.nii.gz
Warning: Origin mismatch between segmentation and corresponding images.
Origin images: (-249.10000610351562, 249.02317810058594, -651.0).
Origin seg: (-1.0, -1.0, 1.0).
Image files: ['/autodl-fs/data/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/imagesTr/liver_00048_0000.nii.gz'].
Seg file: /autodl-fs/data/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/labelsTr/liver_00048.nii.gz
Warning: Direction mismatch between segmentation and corresponding images.
Direction images: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0).
Direction seg: (-1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0).
Image files: ['/autodl-fs/data/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/imagesTr/liver_00048_0000.nii.gz'].
Seg file: /autodl-fs/data/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/labelsTr/liver_00048.nii.gz
RuntimeError: Some images have errors. Please check text output above to see which one(s) and what's going on.
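- The log reports three problems between image and label for cases 48, 49, 50, 51 and 52:
  - Spacing mismatch: the segmentation and its image disagree on voxel spacing, which usually means their voxel sizes differ. The image or the segmentation may need to be resampled.
  - Origin mismatch: the warning says the origin of the segmentation differs from that of the image. Inconsistent origins can cause errors in later processing.
  - Direction mismatch: the image and the segmentation also have different directions. The direction matrix defines how the voxel axes are oriented in 3D space.
- I opened case 47 (fine) and case 48 (broken) in a 3D viewer. They look as follows:
- Fix 1: align spacing, origin and direction (recommended)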
import SimpleITK as sitk
import os


def correct_segmentation_metadata(image_path, seg_path, output_seg_path):
    # Load image and segmentation
    image = sitk.ReadImage(image_path)
    seg = sitk.ReadImage(seg_path)
    # Check and correct spacing
    if image.GetSpacing() != seg.GetSpacing():
        print(f"Correcting spacing for segmentation: {seg_path}")
        seg.SetSpacing(image.GetSpacing())
    # Check and correct origin
    if image.GetOrigin() != seg.GetOrigin():
        print(f"Correcting origin for segmentation: {seg_path}")
        seg.SetOrigin(image.GetOrigin())
    # Check and correct direction
    if image.GetDirection() != seg.GetDirection():
        print(f"Correcting direction for segmentation: {seg_path}")
        seg.SetDirection(image.GetDirection())
    # Save the corrected segmentation
    sitk.WriteImage(seg, output_seg_path)
    print(f"Corrected segmentation saved to: {output_seg_path}")


def process_dataset(image_dir, seg_dir, output_seg_dir):
    if not os.path.exists(output_seg_dir):
        os.makedirs(output_seg_dir)
    # Loop through image files and find corresponding segmentation files
    for image_file in os.listdir(image_dir):
        image_path = os.path.join(image_dir, image_file)
        # Assuming segmentation files have similar names but without "_0000" suffix
        seg_file = image_file.replace("_0000", "")
        seg_path = os.path.join(seg_dir, seg_file)
        if os.path.exists(seg_path):
            output_seg_path = os.path.join(output_seg_dir, seg_file)
            correct_segmentation_metadata(image_path, seg_path, output_seg_path)
        else:
            print(f"Segmentation file not found for image: {image_file}")


# Set your paths here
image_dir = '/root/autodl-fs/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/imagesTr'
seg_dir = '/root/autodl-fs/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/labelsTr'
output_seg_dir = '/root/autodl-fs/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/labelsTr_corrected'
# Run the correction
process_dataset(image_dir, seg_dir, output_seg_dir)
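Note: this script corrects every label whose spacing, origin or direction disagrees with its image, but the errors above only concern cases 48, 49, 50, 51 and 52, so I replaced only those five labels and left the others untouched.
- Fix 2: resample to (1, 1, 1)
Same code as for KiTS19, with the target spacing changed accordingly.
- Finally, run the preprocessing (with -d set to the ID of the dataset being processed, e.g. 208 for the LiTS dataset above):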
nnUNetv2_plan_and_preprocess -d 209 --verify_dataset_integrity
3.4.2 nnU-Net preprocessing
nnUNetv2_plan_and_preprocess -d 209 --verify_dataset_integrity
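On the CPU-only instance, preprocessing failed with: Killed.
On the GPU instance, preprocessing ran normally:
Note: at this point batch_size and patch_size are set automatically. After preprocessing you can inspect them in nnUNetPlans.json.
- To change batch_size and patch_size, see also the article "How to modify nnUNet's batch_size and patch_size".
You can edit autodl-fs/u-Mamba-main/umamba/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py at line 345. The definition of get_plans_for_configuration at line 229 is also worth reading in detail.
autodl-fs/nnUNet-wh/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py provides an example.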
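To look up the automatically chosen values without opening the file by hand, a few lines of Python are enough. This sketch assumes that nnUNetPlans.json has a top-level "configurations" entry with "batch_size" and "patch_size" per configuration; the helper name and the demo numbers are hypothetical, and configurations that inherit from another one may not list both keys.

```python
import json
import os
import tempfile


def read_batch_and_patch_size(plans_file, configuration):
    """Return (batch_size, patch_size) for one configuration of a plans file."""
    with open(plans_file) as f:
        plans = json.load(f)
    cfg = plans["configurations"][configuration]  # assumed layout, see lead-in
    return cfg["batch_size"], cfg["patch_size"]


# Demo with a tiny stand-in file; the numbers are made up for illustration.
demo_plans = {"configurations": {"2d": {"batch_size": 12, "patch_size": [512, 512]}}}
demo_file = os.path.join(tempfile.mkdtemp(), "nnUNetPlans.json")
with open(demo_file, "w") as f:
    json.dump(demo_plans, f)
batch_size, patch_size = read_batch_and_patch_size(demo_file, "2d")
print(batch_size, patch_size)
```

For a real dataset, point plans_file at the nnUNetPlans.json inside the dataset's folder under nnUNet_preprocessed.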
3.4.3 Data augmentation
For details, see: nnUnet code walkthrough: data augmentation
- Start from autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
- Data reading: line 578 of autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py; the details are in autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/dataloading/nnunet_dataset.py
(1) nnUNetTrainer.py, line 578.
(2) autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/dataloading/nnunet_dataset.py.
- Data loading: see autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/dataloading/data_loader_3d.py; the border padding starts at line 24 of that file.
- Border padding: see autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/dataloading/base_data_loader.py, lines 64-80.
(1) Lower bounds: lines 82-135.
(2) Upper bounds: line 137.
(3) Valid (unpadded) bounds: lines 32-33 of autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/dataloading/data_loader_3d.py.
- Data augmentation: autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/data_augmentation/
(1) Computing the patch size: autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py
(2) Transforms: autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/training/data_augmentation/custom_transforms
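- cascade_transforms:
- MoveSegAsOneHotToData: converts the segmentation labels at a given index to one-hot encoding and appends them to the target data.
- RemoveRandomConnectedComponentFromOneHotEncodingTransform: randomly removes small connected components from the specified channels during preprocessing, optionally filling them with the values of another class.
- ApplyRandomBinaryOperatorTransform: randomly applies a binary operation (dilation, erosion, closing, opening) to the specified channels, with a randomly chosen structuring-element size.
- deep_supervision_donwsampling: downsamples the input segmentation to multiple scales and stores the results in the output dictionary.
- limited_length_multithreaded_augmenter: virtual length. The my_imaginary_length parameter lets the user give the augmenter a virtual length, which in some cases may be used to control the number of augmentations or batches.
- manipulating_data_dict: removes the specified key/value pairs from the data dictionary passed in.
- masking: sets the values of the specified channels to a given value according to the provided mask.
- region_based_training: merges specific regions of the segmentation map into new regions and stores the result in the output dictionary.
- transforms_for_dummy_2d:
- Convert3DTo2DTransform: converts a 5D array of shape (b, c, x, y, z) into a 4D array of shape (b, c * x, y, z) by merging the channel axis with the first spatial axis.
- Convert2DTo3DTransform: converts a 4D array of shape (b, c * x, y, z) back into a 5D array of shape (b, c, x, y, z).
# Data augmentation in nnU-Net v1:
SegChannelSelectionTransform: if the label has several channels, selects one of them. (I doubt this is needed; labels are usually single-channel.)
SpatialTransform: the all-in-one spatial transform, covering rotation, deformation, scaling and cropping.
GammaTransform: gamma transform, a non-linear operation on the grey values that makes the output grey value a power function of the input grey value.
MirrorTransform: mirroring; randomly flips along each axis, with a default flip probability of 0.5 per axis.
MaskTransform: data[mask < 0] = 0, i.e. sets everything outside the mask (mask below 0) to zero.
RemoveLabelTransform: replaces a label value, e.g. RemoveLabelTransform(-1, 0) replaces label -1 with 0.
RenameTransform: renames an entry of data_dict, or drops the seg part of data_dict; not an augmentation as such.
NumpyToTensor: as the name says, converts numpy arrays to tensors.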
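The shape change described for Convert3DTo2DTransform and Convert2DTo3DTransform can be reproduced with a plain NumPy reshape. The sketch below only illustrates the idea and is not the code of those classes; the function names are hypothetical.

```python
import numpy as np


def convert_3d_to_2d(data):
    """Merge channel and first spatial axis: (b, c, x, y, z) -> (b, c * x, y, z)."""
    b, c, x, y, z = data.shape
    # Remember (c, x) so the merge can be undone later.
    return data.reshape(b, c * x, y, z), (c, x)


def convert_2d_to_3d(data, original_cx):
    """Undo convert_3d_to_2d using the remembered (c, x)."""
    c, x = original_cx
    b, _, y, z = data.shape
    return data.reshape(b, c, x, y, z)


volume = np.arange(2 * 1 * 4 * 8 * 8, dtype=np.float32).reshape(2, 1, 4, 8, 8)
flat, cx = convert_3d_to_2d(volume)
restored = convert_2d_to_3d(flat, cx)
print(flat.shape, restored.shape)
```

Treating the slices as extra channels lets 2D spatial transforms run on 3D patches, after which the original layout is restored.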
3.4.4 Problems encountered
- Problem 1
RuntimeError: Some background worker is 6 feet under. Yuck.
OK jokes aside.
One of your background processes is missing. This could be because of an error (look for an error message) or because it was killed by your OS due to running out of RAM. If you don't see an error message, out of RAM is likely the problem. In that case reducing the number of workers might help
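- Fix: run on a two-GPU instance or switch to a server with more RAM (48 GB). As the message itself says, reducing the number of workers may also help.
- Problem 2: one or more files expected to be .npz (compressed NumPy archives) are corrupted, or are not .npz files at all.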
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/bin/nnUNetv2_train", line 33, in <module>
sys.exit(load_entry_point('nnunetv2', 'console_scripts', 'nnUNetv2_train')())
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/run/run_training.py", line 268, in run_training_entry
run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/run/run_training.py", line 204, in run_training
nnunet_trainer.run_training()
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 1250, in run_training
self.on_train_start()
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 823, in on_train_start
unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/training/dataloading/utils.py", line 113, in unpack_dataset
p.starmap(_convert_to_npy, zip(npz_files,
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/pool.py", line 375, in starmap
return self._map_async(func, iterable, starmapstar, chunksize).get()
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/pool.py", line 774, in get
raise self._value
zipfile.BadZipFile: File is not a zip file
Fix: I checked the .npz files, found that files 92, 93, 94 and 95 were damaged, and copied them over again.
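Since an .npz file is a zip archive, damaged files can be located with the standard zipfile module instead of checking them one by one. This is a minimal sketch: the helper name find_bad_npz is hypothetical, and the demo uses placeholder files in a temporary folder rather than real preprocessed data.

```python
import os
import tempfile
import zipfile


def find_bad_npz(folder):
    """Return the .npz files in folder that are not readable zip archives."""
    bad = []
    for name in sorted(os.listdir(folder)):
        if not name.endswith(".npz"):
            continue
        path = os.path.join(folder, name)
        try:
            with zipfile.ZipFile(path) as zf:
                # testzip() returns the first corrupt member, or None if all are fine.
                if zf.testzip() is not None:
                    bad.append(name)
        except zipfile.BadZipFile:
            bad.append(name)
    return bad


# Demo: one valid archive and one file that is not a zip file at all.
demo_dir = tempfile.mkdtemp()
with zipfile.ZipFile(os.path.join(demo_dir, "case_00001.npz"), "w") as zf:
    zf.writestr("data.npy", b"placeholder")
with open(os.path.join(demo_dir, "case_00092.npz"), "wb") as f:
    f.write(b"truncated copy")
bad_files = find_bad_npz(demo_dir)
print(bad_files)
```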
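4. Training
I changed the number of epochs to 500; after one full run I found that training had more or less converged at around epoch 400.
The setting (num_epochs) is at line 147 of autodl-fs/nnUNet-wh/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py.
4.1 Running from the terminal
209 is the dataset ID. 2d is the configuration (one of 2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres). 0 is fold 0 of the five-fold cross-validation (folds 0-4): the 210 cases are split into 5 parts, 168 cases are used for training and 42 for validation. all trains a single model on all 210 cases and, after training, validates on all 210 cases. 5 is not five-fold cross-validation: nnU-Net then picks a random training and validation set at a 4:1 ratio. Source: nnUNetV2 tutorial, very detailed!! (using the MSD decathlon dataset)
4.1.1 Running one fold at a time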
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
Run in turn:
nnUNetv2_train 209 2d 0
# nnUNetv2_train 209 2d 1
# nnUNetv2_train 209 2d 2
# nnUNetv2_train 209 2d 3
# nnUNetv2_train 209 2d 4
# nnUNetv2_train 209 2d all
or
nnUNetv2_train 208 3d_fullres all -tr nnUNetTrainerUMambaBot
Note: to predict with a single model, train the all fold and pass -f all to nnUNetv2_predict.
- Saving the terminal output to a log file (I have not tried this):
nnUNetv2_train 209 2d 0 > train.log 2>&1
- Resuming an interrupted run:
nnUNetv2_train 209 2d 0 --c
Note: the training output is in autodl-fs/nnUNet-wh/DATASET/nnUNet_trained_models/Dataset209_KiTS2019/nnUNetTrainer__nnUNetPlans__2d/fold_0
The predictions for the 42 validation cases of fold 0 are in autodl-fs/nnUNet-wh/DATASET/nnUNet_trained_models/Dataset209_KiTS2019/nnUNetTrainer__nnUNetPlans__2d/fold_0/validation
4.1.2 Running all 5 folds automatically
- Create a file named tst.sh with the following content:
for fold in {0..4}
do
# echo "nnUNetv2_train 1 3d_lowres $fold"
nnUNetv2_train 1 3d_lowres $fold
done
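- Run tst.sh in a terminal with the virtual environment active. The source command executes the script in the current shell (see "Why can't cd be used in a shell script? Cause and fix" and "nnU-Net v2: from environment setup to training on your own dataset"). The command is:
source /root/autodl-tmp/nnU-Net/tst.sh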
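The same loop can be driven from Python, which makes it easy to change the dataset ID or add flags such as --npz. This is a sketch under the assumption that nnUNetv2_train is on the PATH of the active environment; the helper name build_train_commands is hypothetical, and the block only prints the commands.

```python
import shlex


def build_train_commands(dataset_id, configuration, folds=range(5), extra_args=()):
    """Return one nnUNetv2_train command (as an argument list) per fold."""
    return [["nnUNetv2_train", str(dataset_id), configuration, str(fold), *extra_args]
            for fold in folds]


commands = build_train_commands(209, "2d")
for cmd in commands:
    print(shlex.join(cmd))
    # To actually train, run this inside the activated environment:
    # import subprocess; subprocess.run(cmd, check=True)
```

With check=True a failing fold stops the loop instead of silently moving on to the next one.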
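4.2 Results after training
- The training results are in
autodl-fs/LightM-UNet-master/data/nnUNet_results/Dataset208_LiTS2017/nnUNetTrainerLightMUNet__nnUNetPlans__3d_fullres
- fold_all is the result folder. It contains:
- debug.json: a summary of the blueprint and inferred parameters used to train this model. Not easy to read, but very useful for debugging.
- model_best.model/model_best.model.pkl: checkpoint of the best model identified during training. (These are the nnU-Net v1 names; v2 writes checkpoint_best.pth, see 5.1.2.)
- model_final_checkpoint.model/model_final_checkpoint.model.pkl: checkpoint of the final model (after training has ended). This is the one used for validation and inference. (In v2: checkpoint_final.pth.)
- networkarchitecture.pdf (only if hiddenlayer is installed!): a PDF document with a diagram of the network architecture.
- progress.png: plot of the training (blue) and validation (red) loss during training. It also shows an approximation of the evaluation metric (green), namely the mean Dice score of the foreground classes.
- validation_raw: this folder holds the validation cases predicted after training has finished. summary.json contains the validation metrics (the mean over all cases is given at the end of the file). (In v2 the folder is named validation, as in the fold_0/validation path in 4.1.1.)
- training_log: printed continuously during training. nnU-Net's default loss (Dice plus cross-entropy, with the Dice term negated) tends toward -1, so the per-epoch loss in the log should become negative, and the closer it gets to -1 the better.
- There are two ways to look at the training results:
- Open progress.png for a visual impression.
- Open validation_raw/summary.json, which lists the metrics for every validation case and, at the very end, their mean.
Contents of summary.json: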
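The mean value can also be pulled out of summary.json programmatically. The sketch below assumes the file has a top-level "foreground_mean" entry with a "Dice" value, which is how I understand the nnU-Net v2 layout; check the keys of your own file first, since other versions may differ. The helper name and the demo numbers are hypothetical.

```python
import json
import os
import tempfile


def mean_foreground_dice(summary_file):
    """Return the mean foreground Dice stored in a summary.json file."""
    with open(summary_file) as f:
        summary = json.load(f)
    # Assumed layout: {"foreground_mean": {"Dice": ...}, ...}
    return summary["foreground_mean"]["Dice"]


# Demo with a stand-in file; the numbers are made up for illustration.
demo_summary = {"foreground_mean": {"Dice": 0.9}, "metric_per_case": []}
demo_file = os.path.join(tempfile.mkdtemp(), "summary.json")
with open(demo_file, "w") as f:
    json.dump(demo_summary, f)
dice_value = mean_foreground_dice(demo_file)
print(dice_value)
```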
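4.3 Using the five-fold models at prediction time
The --npz flag makes nnU-Net save the softmax outputs during the final validation. They are required for this step. The exported softmax predictions are very large and can take up a lot of disk space, so this is not enabled by default. If you originally ran without the --npz flag but now need the softmax predictions, rerun the validation with: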
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD --val --npz
Run in turn:
nnUNetv2_train 209 2d 0 --val --npz
nnUNetv2_train 209 2d 1 --val --npz
nnUNetv2_train 209 2d 2 --val --npz
nnUNetv2_train 209 2d 3 --val --npz
nnUNetv2_train 209 2d 4 --val --npz
This effectively reruns the 42 validation cases of each fold and saves the softmax outputs.
4.4 Problems encountered during training
Reference: nnunet (2) Common Issues and their Solutions
- After training had run for a while and was interrupted, restarting it failed with:
ValueError: mmap length is greater than file size
or
Exception in background worker 2:
mmap length is greater than file size
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 53, in producer
item = next(data_loader)
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/batchgenerators/dataloading/data_loader.py", line 126, in __next__
return self.generate_train_batch()
File "/autodl-fs/data/u-Mamba-main/umamba/nnunetv2/training/dataloading/data_loader_3d.py", line 19, in generate_train_batch
data, seg, properties = self._data.load_case(i)
File "/autodl-fs/data/u-Mamba-main/umamba/nnunetv2/training/dataloading/nnunet_dataset.py", line 86, in load_case
data = np.load(entry['data_file'][:-4] + ".npy", 'r')
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/numpy/lib/npyio.py", line 453, in load
return format.open_memmap(file, mode=mmap_mode,
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/numpy/lib/format.py", line 945, in open_memmap
marray = numpy.memmap(filename, dtype=dtype, shape=shape, order=order,
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/numpy/core/memmap.py", line 268, in __new__
mm = mmap.mmap(fid.fileno(), bytes, access=acc, offset=start)
ValueError: mmap length is greater than file size
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/bin/nnUNetv2_train", line 33, in <module>
sys.exit(load_entry_point('nnunetv2', 'console_scripts', 'nnUNetv2_train')())
File "/autodl-fs/data/u-Mamba-main/umamba/nnunetv2/run/run_training.py", line 268, in run_training_entry
run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
File "/autodl-fs/data/u-Mamba-main/umamba/nnunetv2/run/run_training.py", line 204, in run_training
nnunet_trainer.run_training()
File "/autodl-fs/data/u-Mamba-main/umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 1258, in run_training
train_outputs.append(self.train_step(next(self.dataloader_train)))
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 196, in __next__
item = self.__get_next_item()
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py", line 181, in __get_next_item
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
RuntimeError: One or more background workers are no longer alive. Exiting. Please check the print statements above for the actual error message
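- Fix: go to the folder
autodl-fs/u-Mamba-main/data/nnUNet_preprocessed/Dataset208_LiTS2017/nnUNetPlans_3d_fullres
and delete all .npy files in it (the traceback shows they are derived from the .npz files of the same name), then start training again.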
rm *.npy
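Before deleting everything, it can be useful to see which .npy files are actually broken. The sketch below opens each file the same way the traceback does (np.load with memory mapping) and collects the ones that fail; the helper name find_broken_npy is hypothetical, and the demo fakes an interrupted write by truncating a file in a temporary folder.

```python
import os
import tempfile

import numpy as np


def find_broken_npy(folder):
    """Return the .npy files in folder that cannot be memory-mapped."""
    broken = []
    for name in sorted(os.listdir(folder)):
        if not name.endswith(".npy"):
            continue
        path = os.path.join(folder, name)
        try:
            arr = np.load(path, mmap_mode="r")
            del arr  # release the mapping right away
        except (ValueError, OSError):
            broken.append(name)
    return broken


# Demo: one complete file and one that was cut short.
demo_dir = tempfile.mkdtemp()
np.save(os.path.join(demo_dir, "good.npy"), np.zeros((64, 64), dtype=np.float32))
np.save(os.path.join(demo_dir, "cut.npy"), np.zeros((64, 64), dtype=np.float32))
with open(os.path.join(demo_dir, "cut.npy"), "r+b") as f:
    f.truncate(1024)  # keep the header, drop most of the data
broken_files = find_broken_npy(demo_dir)
print(broken_files)
```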
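5. Prediction
5.1 Predicting with a single model
5.1.1 Predicting with a single model
- Single-model prediction. In the examples below, -f 0 selects the fold-0 model.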
nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c CONFIGURATION --save_probabilities
- Example:
nnUNetv2_predict -i ./data/nnUNet_raw/Dataset208_LiTS2017/imagesTs -o ./evaluation/segs/Dataset208_UMambaBot -d 208 -c 3d_fullres -f all -tr nnUNetTrainerUMambaBot --disable_tta
or
nnUNetv2_predict -i ${nnUNet_raw}/Dataset209_KiTS2019/imagesTs -o output0 -d 209 -c 2d -f 0
# nnUNetv2_predict -i ${nnUNet_raw}/Dataset209_KiTS2019/imagesTs -o output1 -d 209 -c 2d -f 1
# nnUNetv2_predict -i ${nnUNet_raw}/Dataset209_KiTS2019/imagesTs -o output2 -d 209 -c 2d -f 2
# nnUNetv2_predict -i ${nnUNet_raw}/Dataset209_KiTS2019/imagesTs -o output3 -d 209 -c 2d -f 3
# nnUNetv2_predict -i ${nnUNet_raw}/Dataset209_KiTS2019/imagesTs -o output4 -d 209 -c 2d -f 4
# nnUNetv2_predict -i ${nnUNet_raw}/Dataset209_KiTS2019/imagesTs -o output_all -d 209 -c 2d -f all
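5.1.2 Predicting from a run stopped early after convergence
- Training was set to 1000 epochs, but by epoch 500 the model had already converged (the checkpoint and the progress plot are refreshed every 50 epochs). After stopping the run, prediction works as follows:
- First, look at the prediction code in autodl-fs/LightM-UNet-master/lightm-unet/nnunetv2/inference/predict_from_raw_data.py; the command-line arguments are defined from line 721.
- Next, search (Ctrl+F) for initialize_from_trained_model_folder (line 68): the checkpoint it loads is named checkpoint_final.pth.
- Line 759 confirms that checkpoint_final.pth is the default.
- An interrupted run leaves two checkpoints, checkpoint_latest.pth and checkpoint_best.pth. So it is enough to rename checkpoint_latest.pth to checkpoint_final.pth (passing -chk checkpoint_latest.pth avoids the rename). To use the best model instead, add -chk checkpoint_best.pth to the prediction command: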
nnUNetv2_predict -i ./data/nnUNet_raw/Dataset208_LiTS2017/imagesTs -o ./evaluation/segs/Dataset208_LightMUNetAsppbest -d 208 -c 3d_fullres -f all -chk checkpoint_best.pth -tr nnUNetTrainerLightMUNetAspp --disable_tta
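5.2 Finding the best configuration
- Determining the best configuration automatically
Once the desired configurations have been trained (full five-fold cross-validation), you can have nnU-Net identify the best combination for you automatically:
Reference: The full nnU-Net v2 workflow on a server (a beginner's notes, pitfalls included)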
nnUNetv2_find_best_configuration DATASET_NAME_OR_ID -c CONFIGURATIONS
nnUNetv2_find_best_configuration -h
lists the available arguments. Example:
nnUNetv2_find_best_configuration 101 -f 0 1 2 3 4 -c 2d
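5.3 Ensembling predictions, building on 5.1 and 5.2
- An ensemble can combine any number of folders, each being a prediction folder with .npz files generated by nnUNetv2_predict (run with --save_probabilities).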
nnUNetv2_ensemble -i FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -np NUM_PROCESSES
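Conceptually, ensembling means averaging the class probabilities saved by each model and then taking the most probable class per voxel. The sketch below shows that idea on made-up numbers; it is an illustration under that assumption, not the code of nnUNetv2_ensemble, and the function name is hypothetical.

```python
import numpy as np


def ensemble_probabilities(prob_maps):
    """Average per-class probability maps and return the argmax label map."""
    mean_prob = np.mean(np.stack(prob_maps, axis=0), axis=0)  # shape: (classes, ...)
    return mean_prob.argmax(axis=0)


# Two made-up models, 2 classes (rows), 3 voxels (columns).
model_a = np.array([[0.9, 0.4, 0.2],
                    [0.1, 0.6, 0.8]])
model_b = np.array([[0.8, 0.7, 0.4],
                    [0.2, 0.3, 0.6]])
labels = ensemble_probabilities([model_a, model_b])
print(labels)
```

For the middle voxel the two models disagree (0.6 versus 0.3 for class 1); the average of 0.45 is below 0.5, so the ensemble assigns class 0.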
5.4 Postprocessing
nnUNetv2_apply_postprocessing -i FOLDER_WITH_PREDICTIONS -o OUTPUT_FOLDER --pp_pkl_file POSTPROCESSING_FILE -plans_json PLANS_FILE -dataset_json DATASET_JSON_FILE
5.5 Compressing the results into predictions.zip
zip predictions.zip prediction_*.nii.gz
5.6 Problems encountered during prediction
- The model folder cannot be found, because the prediction command was mistyped:
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/bin/nnUNetv2_predict", line 33, in <module>
sys.exit(load_entry_point('nnunetv2', 'console_scripts', 'nnUNetv2_predict')())
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/inference/predict_from_raw_data.py", line 826, in predict_entry_point
predictor.initialize_from_trained_model_folder(
File "/autodl-fs/data/LightM-UNet-master/lightm-unet/nnunetv2/inference/predict_from_raw_data.py", line 77, in initialize_from_trained_model_folder
dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
File "/root/miniconda3/envs/whma/lib/python3.10/site-packages/batchgenerators/utilities/file_and_folder_operations.py", line 68, in load_json
with open(file, 'r') as f:
FileNotFoundError: [Errno 2] No such file or directory: '/autodl-fs/data/LightM-UNet-master/data/nnUNet_results/Dataset208_LiTS2017/nnUNetTrainerLightMUNetAspp--disable_tta__nnUNetPlans__3d_fullres/dataset.json'
- Fix: a space was missing before --disable_tta.
# before
nnUNetv2_predict -i ./data/nnUNet_raw/Dataset208_LiTS2017/imagesTs -o ./evaluation/segs/Dataset208_LightMUNetAspp -d 208 -c 3d_fullres -f all -tr nnUNetTrainerLightMUNetAspp--disable_tta
# after
nnUNetv2_predict -i ./data/nnUNet_raw/Dataset208_LiTS2017/imagesTs -o ./evaluation/segs/Dataset208_LightMUNetAspp -d 208 -c 3d_fullres -f all -tr nnUNetTrainerLightMUNetAspp --disable_tta
- During prediction, the background workers die:
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/bin/nnUNetv2_predict", line 33, in <module>
sys.exit(load_entry_point('nnunetv2', 'console_scripts', 'nnUNetv2_predict')())
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/predict_from_raw_data.py", line 831, in predict_entry_point
predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/predict_from_raw_data.py", line 250, in predict_from_files
return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/predict_from_raw_data.py", line 366, in predict_from_data_iterator
proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/utilities/file_path_utilities.py", line 103, in check_workers_alive_and_busy
raise RuntimeError('Some background workers are no longer alive')
RuntimeError: Some background workers are no longer alive
Process SpawnProcess-8:
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/data_iterators.py", line 57, in preprocess_fromfiles_save_to_queue
raise e
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/data_iterators.py", line 50, in preprocess_fromfiles_save_to_queue
target_queue.put(item, timeout=0.01)
File "<string>", line 2, in put
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/managers.py", line 833, in _callmethod
raise convert_to_error(kind, result)
multiprocessing.managers.RemoteError:
---------------------------------------------------------------------------
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/managers.py", line 260, in serve_client
self.id_to_local_proxy_obj[ident]
KeyError: '7f22e4050e80'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/managers.py", line 262, in serve_client
raise ke
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/managers.py", line 256, in serve_client
obj, exposed, gettypeid = id_to_obj[ident]
KeyError: '7f22e4050e80'
---------------------------------------------------------------------------
Process SpawnProcess-10:
Traceback (most recent call last):
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/data_iterators.py", line 57, in preprocess_fromfiles_save_to_queue
raise e
File "/autodl-fs/data/U-Mamba-main/umamba/nnunetv2/inference/data_iterators.py", line 50, in preprocess_fromfiles_save_to_queue
target_queue.put(item, timeout=0.01)
File "<string>", line 2, in put
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/managers.py", line 818, in _callmethod
kind, result = conn.recv()
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/connection.py", line 250, in recv
buf = self._recv_bytes()
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
buf = self._recv(4)
File "/root/miniconda3/envs/whma/lib/python3.10/multiprocessing/connection.py", line 383, in _recv
raise EOFError
EOFError
- Cause: not enough memory.
- Fix: switch to a machine with more memory, or delete some files to free memory.
6. Evaluation
6.1 Built-in evaluation script
python abdomen_DSC_Eval.py --gt_path ~/autodl-fs/u-Mamba-main/data/nnUNet_raw/Dataset208_LiTS2017/labelsTs --seg_path ./segs/Dataset208_UMambaBot --save_path ./Dataset208_UMambaBot.csv
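Adjust the folders, run the command, and check the summary.json report.
6.2 Computing the metrics yourself
This covers binary segmentation, where the label contains only 0 and 1, on 3D nii data (2D data would need its own loading code). The script implements all the metrics and saves the results to Excel. Reference: [Theory + practice] The most complete guide to the image segmentation metrics commonly used in papers, with full code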
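As a quick sanity check of the Dice and IoU definitions used in the script, here is a tiny worked example on two made-up 2 x 3 masks. It is only an illustration with hypothetical values; it omits the small epsilon terms the script adds for numerical safety.

```python
import numpy as np

seg = np.array([[1, 1, 0],
                [0, 1, 0]])   # predicted mask, 3 foreground voxels
gt = np.array([[1, 0, 0],
               [0, 1, 1]])    # ground truth, 3 foreground voxels

intersection = int((seg * gt).sum())       # voxels that are 1 in both masks
union = int(((seg + gt) > 0).sum())        # voxels that are 1 in at least one mask
dice = 2 * intersection / (seg.sum() + gt.sum())
iou = intersection / union
print(intersection, union, dice, iou)
```

Here the masks share 2 voxels and cover 4 voxels together, so Dice = 4/6 and IoU = 2/4; the two are linked by Dice = 2 * IoU / (1 + IoU).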
# 计算三维下各种指标
from __future__ import absolute_import, print_function
import pandas as pd
import GeodisTK
import numpy as np
from scipy.ndimage import binary_erosion, generate_binary_structure
from scipy import ndimage
import os
import nibabel as nib
# pixel accuracy
def binary_pa(s, g):
    """
    calculate the pixel accuracy of two N-d volumes.
    s: the segmentation volume of numpy array
    g: the ground truth volume of numpy array
    """
    pa = ((s == g).sum()) / g.size
    return pa


# Dice evaluation
def binary_dice(s, g):
    """
    calculate the Dice score of two N-d volumes.
    s: the segmentation volume of numpy array
    g: the ground truth volume of numpy array
    """
    assert (len(s.shape) == len(g.shape))
    prod = np.multiply(s, g)
    s0 = prod.sum()
    dice = (2.0 * s0 + 1e-10) / (s.sum() + g.sum() + 1e-10)
    return dice
# IoU evaluation
def binary_iou(s, g):
    assert (len(s.shape) == len(g.shape))
    # voxels where the product equals 1 form the intersection
    intersection = np.multiply(s, g)
    # voxels where the sum is greater than 0 form the union
    union = np.asarray(s + g > 0, np.float32)
    iou = intersection.sum() / (union.sum() + 1e-10)
    return iou
# Hausdorff and ASSD evaluation
def get_edge_points(img):
    """
    get edge points of a binary segmentation result
    """
    dim = len(img.shape)
    if dim == 2:
        strt = generate_binary_structure(2, 1)
    else:
        # 3D structuring element: voxels at distance 1 from the center are neighbors
        strt = generate_binary_structure(3, 1)
    ero = binary_erosion(img, strt)
    edge = np.asarray(img, np.uint8) - np.asarray(ero, np.uint8)
    return edge
def binary_hausdorff95(s, g, spacing=None):
    """
    get the 95th-percentile Hausdorff distance between a binary segmentation and the ground truth
    inputs:
        s: a 3D or 2D binary image for segmentation
        g: a 3D or 2D binary image for ground truth
        spacing: a list for image spacing, length should be 3 or 2
    """
    s_edge = get_edge_points(s)
    g_edge = get_edge_points(g)
    image_dim = len(s.shape)
    assert (image_dim == len(g.shape))
    if spacing is None:
        spacing = [1.0] * image_dim
    else:
        assert (image_dim == len(spacing))
    img = np.zeros_like(s)
    if (image_dim == 2):
        s_dis = GeodisTK.geodesic2d_raster_scan(img, s_edge, 0.0, 2)
        g_dis = GeodisTK.geodesic2d_raster_scan(img, g_edge, 0.0, 2)
    elif (image_dim == 3):
        s_dis = GeodisTK.geodesic3d_raster_scan(img, s_edge, spacing, 0.0, 2)
        g_dis = GeodisTK.geodesic3d_raster_scan(img, g_edge, spacing, 0.0, 2)
    dist_list1 = s_dis[g_edge > 0]
    dist_list2 = g_dis[s_edge > 0]
    if len(dist_list1) == 0 or len(dist_list2) == 0:
        return np.inf  # Return infinity for invalid cases
    dist1 = sorted(dist_list1)[int(len(dist_list1) * 0.95)]
    dist2 = sorted(dist_list2)[int(len(dist_list2) * 0.95)]
    return max(dist1, dist2)
# average symmetric surface distance
def binary_assd(s, g, spacing=None):
    """
    get the average symmetric surface distance between a binary segmentation and the ground truth
    inputs:
        s: a 3D or 2D binary image for segmentation
        g: a 3D or 2D binary image for ground truth
        spacing: a list for image spacing, length should be 3 or 2
    """
    s_edge = get_edge_points(s)
    g_edge = get_edge_points(g)
    image_dim = len(s.shape)
    assert (image_dim == len(g.shape))
    if spacing is None:
        spacing = [1.0] * image_dim
    else:
        assert (image_dim == len(spacing))
    img = np.zeros_like(s)
    if (image_dim == 2):
        s_dis = GeodisTK.geodesic2d_raster_scan(img, s_edge, 0.0, 2)
        g_dis = GeodisTK.geodesic2d_raster_scan(img, g_edge, 0.0, 2)
    elif (image_dim == 3):
        s_dis = GeodisTK.geodesic3d_raster_scan(img, s_edge, spacing, 0.0, 2)
        g_dis = GeodisTK.geodesic3d_raster_scan(img, g_edge, spacing, 0.0, 2)
    ns = s_edge.sum()
    ng = g_edge.sum()
    if ns == 0 or ng == 0:
        return np.inf  # Avoid division by zero
    s_dis_g_edge = s_dis * g_edge
    g_dis_s_edge = g_dis * s_edge
    assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng)
    return assd
# relative volume error evaluation
def binary_relative_volume_error(s_volume, g_volume):
    s_v = float(s_volume.sum())
    g_v = float(g_volume.sum())
    assert (g_v > 0)
    rve = abs(s_v - g_v) / g_v
    return rve
def compute_class_sens_spec(pred, label):
    """
    Compute sensitivity and specificity for a particular example
    for a given class for binary.
    Args:
        pred (np.array): binary array of predictions, shape is
            (height, width, depth).
        label (np.array): binary array of labels, shape is
            (height, width, depth).
    Returns:
        sensitivity (float): true positive rate (recall), tp / (tp + fn).
        specificity (float): true negative rate, tn / (tn + fp).
    """
    tp = np.sum((pred == 1) & (label == 1))
    tn = np.sum((pred == 0) & (label == 0))
    fp = np.sum((pred == 1) & (label == 0))
    fn = np.sum((pred == 0) & (label == 1))
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return sensitivity, specificity
def get_evaluation_score(s_volume, g_volume, spacing, metric):
    if (len(s_volume.shape) == 4):
        assert (s_volume.shape[0] == 1 and g_volume.shape[0] == 1)
        s_volume = np.reshape(s_volume, s_volume.shape[1:])
        g_volume = np.reshape(g_volume, g_volume.shape[1:])
    if (s_volume.shape[0] == 1):
        s_volume = np.reshape(s_volume, s_volume.shape[1:])
        g_volume = np.reshape(g_volume, g_volume.shape[1:])
    metric_lower = metric.lower()
    if (metric_lower == "dice"):
        score = binary_dice(s_volume, g_volume)
    elif (metric_lower == "iou"):
        score = binary_iou(s_volume, g_volume)
    elif (metric_lower == 'assd'):
        score = binary_assd(s_volume, g_volume, spacing)
    elif (metric_lower == "hausdorff95"):
        score = binary_hausdorff95(s_volume, g_volume, spacing)
    elif (metric_lower == "rve"):
        score = binary_relative_volume_error(s_volume, g_volume)
    elif (metric_lower == "volume"):
        voxel_size = 1.0
        for dim in range(len(spacing)):
            voxel_size = voxel_size * spacing[dim]
        score = g_volume.sum() * voxel_size
    else:
        raise ValueError("unsupported evaluation metric: {0:}".format(metric))
    return score
if __name__ == '__main__':
    seg_path = 'path/to/your/segmentation/folder'
    gd_path = 'path/to/your/label/folder'
    save_dir = 'path/to/output/folder'
    os.makedirs(save_dir, exist_ok=True)
    seg = sorted(os.listdir(seg_path))
    dices, hds, rves, case_names, senss, specs = [], [], [], [], [], []
    for name in seg:
        if not name.startswith('.') and name.endswith('nii.gz'):
            # load the segmentation and the label with the same file name
            seg_ = nib.load(os.path.join(seg_path, name))
            seg_arr = seg_.get_fdata().astype('float32')
            gd_ = nib.load(os.path.join(gd_path, name))
            gd_arr = gd_.get_fdata().astype('float32')
            case_names.append(name)
            # 95th-percentile Hausdorff distance
            # (spacing=None means unit spacing, so the distance is in voxels)
            hd_score = get_evaluation_score(seg_arr, gd_arr, spacing=None, metric='hausdorff95')
            hds.append(hd_score)
            # relative volume error
            rve = get_evaluation_score(seg_arr, gd_arr, spacing=None, metric='rve')
            rves.append(rve)
            # Dice
            dice = get_evaluation_score(seg_.get_fdata(), gd_.get_fdata(), spacing=None, metric='dice')
            dices.append(dice)
            # sensitivity and specificity
            sens, spec = compute_class_sens_spec(seg_.get_fdata(), gd_.get_fdata())
            senss.append(sens)
            specs.append(spec)
    # collect everything in a pandas DataFrame, one row per case
    data = {'dice': dices, 'RVE': rves, 'Sens': senss, 'Spec': specs, 'HD95': hds}
    df = pd.DataFrame(data=data, columns=['dice', 'RVE', 'Sens', 'Spec', 'HD95'], index=case_names)
    df.to_csv(os.path.join(save_dir, 'metrics.csv'))
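Before running the script on real data, it is worth checking the overlap metrics on a toy volume whose answer can be worked out by hand. The sketch below re-implements Dice, IoU, sensitivity and specificity with NumPy only (so it does not need GeodisTK or nibabel); the function names `dice`, `iou` and `sens_spec` and the toy volumes are illustrative assumptions, not part of the script above.

```python
import numpy as np


def dice(s, g, eps=1e-10):
    """Dice = 2|S∩G| / (|S| + |G|) for binary volumes."""
    inter = np.logical_and(s, g).sum()
    return (2.0 * inter + eps) / (s.sum() + g.sum() + eps)


def iou(s, g, eps=1e-10):
    """IoU = |S∩G| / |S∪G| for binary volumes."""
    inter = np.logical_and(s, g).sum()
    union = np.logical_or(s, g).sum()
    return inter / (union + eps)


def sens_spec(s, g):
    """Sensitivity = TP/(TP+FN), specificity = TN/(TN+FP)."""
    tp = np.sum((s == 1) & (g == 1))
    tn = np.sum((s == 0) & (g == 0))
    fp = np.sum((s == 1) & (g == 0))
    fn = np.sum((s == 0) & (g == 1))
    return tp / (tp + fn), tn / (tn + fp)


# Ground truth: a 2x2x2 cube (8 voxels) inside a 4x4x4 volume.
g = np.zeros((4, 4, 4), dtype=np.uint8)
g[1:3, 1:3, 1:3] = 1
# Prediction: the same cube extended by one slice (12 voxels, 8 of them correct).
s = np.zeros_like(g)
s[1:3, 1:3, 1:4] = 1

print(f"Dice = {dice(s, g):.4f}")   # 2*8 / (12+8) = 0.8000
print(f"IoU  = {iou(s, g):.4f}")    # 8 / 12 = 0.6667
sens, spec = sens_spec(s, g)
print(f"Sens = {sens:.4f}, Spec = {spec:.4f}")  # 8/8 = 1.0000, 52/56 = 0.9286
```

If the full script reports very different values for a comparable case, the usual suspects are mismatched file names between the two folders or labels that are not strictly 0 and 1.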
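7. Data processing: quickly producing "mean ± standard deviation" in Excel
- Select the cell where the formula should go.
- Insert a function and paste the formula below (replace B2:B39 with the range of your own data column):
=ROUND(AVERAGE(B2:B39),4)*100&"±"&ROUND(STDEV(B2:B39),4)*100
- Click each of the highlighted ranges (the red boxes in the screenshot) in turn, click number1, and select the column of data.
- This gives the mean and standard deviation for that column. Drag the fill handle to the right to compute the values for the other two columns.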
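The same "mean ± standard deviation" string can be produced in Python instead of Excel. This is a minimal sketch using only the standard library; the helper name `mean_pm_std` and the sample scores are made up for illustration. It uses the sample standard deviation (n - 1 in the denominator), which is what Excel's STDEV computes.

```python
import statistics


def mean_pm_std(values, scale=100.0, digits=2):
    """Format values as 'mean±std', scaled (e.g. to percent) and rounded."""
    mean = statistics.mean(values) * scale
    std = statistics.stdev(values) * scale  # sample std, like Excel's STDEV
    return f"{mean:.{digits}f}±{std:.{digits}f}"


# Hypothetical per-case Dice scores, for illustration only.
scores = [0.90, 0.92, 0.88, 0.94]
print(mean_pm_std(scores))  # 91.00±2.58
```

Formatting with a fixed number of decimals also avoids the long floating-point tails that `ROUND(...)*100` can occasionally produce in a spreadsheet.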