nnUNetDataset
Parameter list
class nnUNetDataset(object):
    def __init__(self, folder: str, case_identifiers: List[str] = None,
                 num_images_properties_loading_threshold: int = 0,
                 folder_with_segs_from_previous_stage: str = None):
folder
Concretely nnUNet_preprocessed/Dataset001_Xxx/nnUNetPlans_2d/, i.e. the preprocessed data of the chosen configuration.
case_identifiers
A list of case ids, e.g. ['1', '2', ..., 'n'].
splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json")
dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None,
                        num_images_properties_loading_threshold=0,
                        folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
# if the split file does not exist we need to create it
if not isfile(splits_file):
    self.print_to_log_file("Creating new 5-fold cross-validation split...")
    all_keys_sorted = list(np.sort(list(dataset.keys())))
    splits = generate_crossval_split(all_keys_sorted, seed=12345, n_splits=5)
    save_json(splits, splits_file)
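generate_crossval_split produces the content of splits_final.json: a list of five {'train': [...], 'val': [...]} dicts. A minimal sketch of how such a split can be built, assuming a seeded sklearn KFold as in nnunetv2:

from typing import List
import numpy as np
from sklearn.model_selection import KFold

def generate_crossval_split(train_identifiers: List[str], seed: int = 12345, n_splits: int = 5) -> List[dict]:
    # one {'train': [...], 'val': [...]} dict per fold, reproducible via the fixed seed
    splits = []
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for train_idx, val_idx in kfold.split(train_identifiers):
        splits.append({'train': list(np.array(train_identifiers)[train_idx]),
                       'val': list(np.array(train_identifiers)[val_idx])})
    return splits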
When do_split() builds the cross-validation split (the snippet above), case_identifiers is None, so get_case_identifiers(folder) collects the identifiers of all npz files in folder and stores them in self.dataset. The first-level key of dataset[...][...] is the case id returned by get_case_identifiers; the second level has two keys, 'data_file' and 'properties_file', holding the paths of that case's npz and pkl files. The npz file holds the actual image data and contains two arrays, 'data' and 'seg'.
The nnUNetDataset class as used when building the cross-validation split:
import os
from typing import List
import numpy as np
import shutil
from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile
from nnunetv2.training.dataloading.utils import get_case_identifiers
class nnUNetDataset(object):
    def __init__(self, folder: str, case_identifiers: List[str] = None,
                 num_images_properties_loading_threshold: int = 0,
                 folder_with_segs_from_previous_stage: str = None):
        super().__init__()
        # print('loading dataset')
        if case_identifiers is None:
            case_identifiers = get_case_identifiers(folder)
        case_identifiers.sort()

        self.dataset = {}
        for c in case_identifiers:
            self.dataset[c] = {}
            self.dataset[c]['data_file'] = join(folder, f"{c}.npz")
            self.dataset[c]['properties_file'] = join(folder, f"{c}.pkl")
            if folder_with_segs_from_previous_stage is not None:
                self.dataset[c]['seg_from_prev_stage_file'] = join(folder_with_segs_from_previous_stage, f"{c}.npz")

        # with few cases, preload all properties pkl files into RAM
        if len(case_identifiers) <= num_images_properties_loading_threshold:
            for i in self.dataset.keys():
                self.dataset[i]['properties'] = load_pickle(self.dataset[i]['properties_file'])

        # load_case below relies on this flag; it is controlled by the nnUNet_keep_files_open env variable
        self.keep_files_open = ('nnUNet_keep_files_open' in os.environ.keys()) and \
                               (os.environ['nnUNet_keep_files_open'].lower() in ('true', '1', 't'))
    def __getitem__(self, key):
        ret = {**self.dataset[key]}
        if 'properties' not in ret.keys():
            # properties are loaded lazily from the pkl file on first access
            ret['properties'] = load_pickle(ret['properties_file'])
        return ret

    def __setitem__(self, key, value):
        return self.dataset.__setitem__(key, value)

    def keys(self):
        return self.dataset.keys()

    def __len__(self):
        return self.dataset.__len__()

    def items(self):
        return self.dataset.items()

    def values(self):
        return self.dataset.values()
    def load_case(self, key):
        entry = self[key]

        # data: prefer an unpacked .npy (memory-mapped read), fall back to the .npz archive
        if 'open_data_file' in entry.keys():
            data = entry['open_data_file']
            # print('using open data file')
        elif isfile(entry['data_file'][:-4] + ".npy"):
            data = np.load(entry['data_file'][:-4] + ".npy", 'r')
            if self.keep_files_open:
                self.dataset[key]['open_data_file'] = data
                # print('saving open data file')
        else:
            data = np.load(entry['data_file'])['data']

        # seg: same logic, via the _seg.npy file or the 'seg' array of the npz
        if 'open_seg_file' in entry.keys():
            seg = entry['open_seg_file']
            # print('using open seg file')
        elif isfile(entry['data_file'][:-4] + "_seg.npy"):
            seg = np.load(entry['data_file'][:-4] + "_seg.npy", 'r')
            if self.keep_files_open:
                self.dataset[key]['open_seg_file'] = seg
                # print('saving open seg file')
        else:
            seg = np.load(entry['data_file'])['seg']

        # cascade: append the previous stage's segmentation to seg as an extra channel
        if 'seg_from_prev_stage_file' in entry.keys():
            if isfile(entry['seg_from_prev_stage_file'][:-4] + ".npy"):
                seg_prev = np.load(entry['seg_from_prev_stage_file'][:-4] + ".npy", 'r')
            else:
                seg_prev = np.load(entry['seg_from_prev_stage_file'])['seg']
            seg = np.vstack((seg, seg_prev[None]))

        return data, seg, entry['properties']
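A quick usage sketch of the class above (the folder is the example path from earlier; the case id '1' is an assumed identifier that has to exist in that folder):

ds = nnUNetDataset('nnUNet_preprocessed/Dataset001_Xxx/nnUNetPlans_2d')
print(len(ds), list(ds.keys())[:3])         # number of cases and a few case ids
data, seg, properties = ds.load_case('1')   # data: (c, z, y, x), seg: (1, z, y, x) numpy arrays
print(data.shape, seg.shape, properties['spacing'])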
A 'properties' entry is produced at __getitem__ time, at the same level as 'data_file' and 'properties_file'; its content has the following format:
'properties': {
    'spacing': (999, 1, 1),
    'shape_before_cropping': (1, 512, 512),
    'bbox_used_for_cropping': [[0, 1], [0, 512], [0, 512]],
    'shape_after_cropping_and_before_resampling': (1, 512, 512),
    'class_locations': {
        np.int64(1): array([[  0,   0, 367, 369],
                            [  0,   0, 392, 227],
                            [  0,   0, 394, 181],
                            ...,
                            [  0,   0, 214, 152],
                            [  0,   0, 192, 163],
                            [  0,   0, 338, 154]])
    }
}
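class_locations is what foreground oversampling later draws from; as I read it, each row comes from np.argwhere over the (1, z, y, x) seg array, i.e. [channel, z, y, x] voxel coordinates of that label. A minimal access sketch (the case id '1' is again assumed):

entry = ds['1']                                    # __getitem__ loads the pkl lazily on first access
fg = entry['properties']['class_locations'][1]     # all label-1 voxels, rows of [c, z, y, x]
center = fg[np.random.choice(len(fg))]             # e.g. pick one to center a foreground patch on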
num_images_properties_loading_threshold
When the number of cases does not exceed this threshold, all pkl (properties) files are loaded into RAM up front; otherwise they are loaded lazily in __getitem__.
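With the default of 0 nothing is preloaded; passing a large enough value preloads every case's properties (a minimal sketch, the 1000 is an arbitrary value chosen here):

ds = nnUNetDataset('nnUNet_preprocessed/Dataset001_Xxx/nnUNetPlans_2d',
                   num_images_properties_loading_threshold=1000)
# for datasets with <= 1000 cases, every entry now already carries 'properties' in RAM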
folder_with_segs_from_previous_stage
Only used by the cascade: the folder holding the previous stage's segmentations. When it is not None, each case additionally gets a 'seg_from_prev_stage_file' entry, and load_case stacks that segmentation onto seg as an extra channel (the np.vstack at the end of load_case).
nnUNetDataLoader2D
Parameter list
data
An nnUNetDataset; it is passed all the way up to the base DataLoader class and stored there as _data.
batch_size
patch_size
The enlarged initial patch size that is actually sampled from the volume: (602, 602), i.e. int(512 / 0.85) per axis, big enough that rotation/scaling augmentation can be applied before the patch is cropped back to final_patch_size.
final_patch_size
(512, 512): the patch size the network is actually trained on, after augmentation and center-cropping.
label_manager
oversample_foreground_percent
0.33: roughly the last third of the samples in each batch are forced to contain foreground (their patch centers are drawn from class_locations).
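Putting these parameters together, constructing the training loader looks roughly like this (a sketch modeled on nnUNetTrainer.get_dataloaders; the variables dataset_tr, batch_size and label_manager stand in for the trainer's attributes, and the concrete values come from the plans):

from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D

dl_tr = nnUNetDataLoader2D(dataset_tr,            # nnUNetDataset with this fold's training cases
                           batch_size,            # from the plans
                           (602, 602),            # initial patch size, 602 == int(512 / 0.85)
                           (512, 512),            # final patch size the network is trained on
                           label_manager,         # plans_manager.get_label_manager(dataset_json)
                           oversample_foreground_percent=0.33)
batch = next(dl_tr)  # dict with 'data', 'seg', 'properties' and 'keys'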