import itertools
import multiprocessing
import os
from time import sleep
from typing import Tuple, Union, List
import numpy as np
import onnxruntime as ort
from acvl_utils.cropping_and_padding.padding import pad_nd_image
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
save_json
from nnunetv2.inference.data_iterators import preprocessing_iterator_fromfiles
# from hhh1 import preprocessing_iterator_fromfiles
from nnunetv2.inference.export_prediction import export_prediction_from_logits, \
convert_predicted_logits_to_segmentation_with_correct_shape
from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \
compute_steps_for_sliding_window
from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder
default_num_processes = 8  # default worker count for preprocessing / segmentation-export pools
class nnUNetONNXPredictor(object):
    """Sliding-window inference for nnU-Net v2 models exported to ONNX.

    Mirrors the behaviour of nnUNetv2's nnUNetPredictor but runs the network
    through onnxruntime. Supports multi-fold ensembling, Gaussian window
    weighting and mirroring-based test-time augmentation (TTA).
    """

    def __init__(self,
                 tile_step_size: float = 0.5,
                 use_gaussian: bool = True,
                 use_mirroring: bool = True,
                 perform_everything_on_device: bool = True,
                 device: str = 'cuda',
                 verbose: bool = False,
                 verbose_preprocessing: bool = False,
                 allow_tqdm: bool = True):
        """
        :param tile_step_size: sliding-window step as a fraction of the patch size
            (0.5 = 50% overlap between neighbouring windows)
        :param use_gaussian: weight window centers higher when fusing overlapping windows
        :param use_mirroring: enable mirroring test-time augmentation
        :param perform_everything_on_device: only honoured for cuda devices;
            forced to False otherwise
        :param device: 'cuda' (optionally 'cuda:N'), 'cpu' or 'mps'
        :param allow_tqdm: show a progress bar during sliding-window prediction
        """
        self.verbose = verbose
        self.verbose_preprocessing = verbose_preprocessing
        self.allow_tqdm = allow_tqdm
        # These are populated by initialize_from_onnx_model_folder().
        self.plans_manager, self.configuration_manager, self.dataset_json, \
            self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None
        self.tile_step_size = tile_step_size
        self.use_gaussian = use_gaussian
        self.use_mirroring = use_mirroring
        # Map the requested device to an onnxruntime execution provider.
        self.device = device
        if device.startswith('cuda'):
            self.provider = ['CUDAExecutionProvider']
        elif device.startswith('cpu'):
            self.provider = ['CPUExecutionProvider']
        elif device.startswith('mps'):
            # Fall back to CPU when CoreML is not available in this ort build.
            self.provider = [
                'CoreMLExecutionProvider'] if 'CoreMLExecutionProvider' in ort.get_available_providers() else [
                'CPUExecutionProvider']
        else:
            raise ValueError(f"不支持的设备: {device}")
        # Performing everything on device is only supported for cuda.
        if not device.startswith('cuda'):
            print(f'perform_everything_on_device=True 仅支持cuda设备! 已自动设置为False')
            perform_everything_on_device = False
        self.perform_everything_on_device = perform_everything_on_device
        # ONNX runtime state
        self.onnx_sessions = []  # one InferenceSession per fold
        self.input_name = None   # model input tensor name
        self.output_name = None  # model output tensor name
        # Configuration
        self.configuration_name = '3d_fullres'
        self.inference_allowed_mirroring_axes = None

    def initialize_from_onnx_model_folder(self, onnx_model_folder: str,
                                          use_folds: Union[Tuple[Union[int, str]], None],
                                          configuration_name: str = '3d_fullres'):
        """Initialize from an ONNX model folder: load sessions, plans and dataset info.

        :param onnx_model_folder: folder containing dataset.json, plans.json and
            fold_X subfolders, each with a model.onnx
        :param use_folds: folds to ensemble; None triggers auto-detection
        :param configuration_name: nnU-Net configuration key inside plans.json
        :raises FileNotFoundError: if a requested fold has no model.onnx
        """
        if use_folds is None:
            use_folds = nnUNetONNXPredictor.auto_detect_available_folds(onnx_model_folder)
        # Load the required metadata files.
        dataset_json = load_json(join(onnx_model_folder, 'dataset.json'))
        plans = load_json(join(onnx_model_folder, 'plans.json'))
        plans_manager = PlansManager(plans)
        if isinstance(use_folds, str):
            use_folds = [use_folds]
        # Create one onnxruntime session per fold.
        onnx_sessions = []
        for f in use_folds:
            f = int(f) if f != 'all' else f
            onnx_path = join(onnx_model_folder, f'fold_{f}', 'model.onnx')
            if not isfile(onnx_path):
                raise FileNotFoundError(f"ONNX模型文件不存在: {onnx_path}")
            # Forward an explicit device id to the CUDA provider ('cuda:N').
            provider_options = None
            if self.device.startswith('cuda'):
                device_parts = self.device.split(':')
                device_id = int(device_parts[-1]) if len(device_parts) > 1 else 0
                provider_options = [{'device_id': device_id}]
            session = ort.InferenceSession(onnx_path, providers=self.provider, provider_options=provider_options)
            onnx_sessions.append(session)
        # All folds share the same graph; read I/O names from the first session.
        self.input_name = onnx_sessions[0].get_inputs()[0].name
        self.output_name = onnx_sessions[0].get_outputs()[0].name
        # Configuration selection (fall back to the first configuration in plans).
        self.configuration_name = configuration_name if configuration_name else next(
            iter(plans['configurations'].keys()))
        self.trainer_name = 'nnUNetTrainer'  # default trainer name
        spatial_dims = len(plans['configurations'][self.configuration_name]['patch_size'])
        # NOTE(review): mirroring axes are assumed to be all spatial axes; the
        # exported model folder does not store the training-time TTA axes.
        self.inference_allowed_mirroring_axes = tuple(range(spatial_dims))
        # Persist the core objects.
        self.plans_manager = plans_manager
        self.configuration_manager = plans_manager.get_configuration(self.configuration_name)
        self.onnx_sessions = onnx_sessions
        self.dataset_json = dataset_json
        self.label_manager = plans_manager.get_label_manager(dataset_json)
        self.allowed_mirroring_axes = self.inference_allowed_mirroring_axes

    @staticmethod
    def auto_detect_available_folds(onnx_model_folder):
        """Detect the folds available in the model folder (fold_X dirs with a model.onnx)."""
        print('use_folds为None,自动检测可用fold')
        fold_folders = subdirs(onnx_model_folder, prefix='fold_', join=False)
        fold_folders = [i for i in fold_folders if i != 'fold_all' and isfile(join(onnx_model_folder, i, 'model.onnx'))]
        use_folds = [int(i.split('_')[-1]) for i in fold_folders]
        print(f'找到fold: {use_folds}')
        return use_folds

    def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]],
                                       output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]],
                                       folder_with_segs_from_prev_stage: str = None,
                                       overwrite: bool = True,
                                       part_id: int = 0,
                                       num_parts: int = 1,
                                       save_probabilities: bool = False):
        """Build matched input/output/prev-stage lists; split work across parts
        and optionally skip cases whose outputs already exist.

        :returns: (input file lists, truncated output filenames or None,
            previous-stage segmentation files or Nones)
        """
        if isinstance(list_of_lists_or_source_folder, str):
            # Turn a source folder into per-case lists of channel files.
            list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(
                list_of_lists_or_source_folder, self.dataset_json['file_ending'])
        # Split cases across parts for multi-process/multi-node usage.
        list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts]
        # Strip the channel suffix ('_0000') plus file ending to get the case id.
        caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in
                   list_of_lists_or_source_folder]
        print(f'处理 {part_id}/{num_parts},共 {len(caseids)} 个病例')
        # Derive output paths (truncated = without file ending).
        if isinstance(output_folder_or_list_of_truncated_output_files, str):
            output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids]
        elif isinstance(output_folder_or_list_of_truncated_output_files, list):
            output_filename_truncated = output_folder_or_list_of_truncated_output_files[part_id::num_parts]
        else:
            output_filename_truncated = None
        # Previous-stage segmentations (cascade models); None per case otherwise.
        seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending'])
                                     if folder_with_segs_from_prev_stage else None for i in caseids]
        # When not overwriting, drop cases whose outputs already exist.
        if not overwrite and output_filename_truncated is not None:
            keep = []
            for i, of in enumerate(output_filename_truncated):
                seg_exists = isfile(of + self.dataset_json['file_ending'])
                prob_exists = isfile(of + '.npz') if save_probabilities else True
                if seg_exists and prob_exists:
                    continue
                keep.append(i)
            output_filename_truncated = [output_filename_truncated[i] for i in keep]
            list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in keep]
            seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in keep]
            print(f'跳过已完成病例,剩余 {len(keep)} 个待处理')
        return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files

    def predict_from_files(self,
                           input_folder: str,
                           output_folder: str,
                           save_probabilities: bool = False,
                           overwrite: bool = True,
                           num_processes_preprocessing: int = default_num_processes,
                           num_processes_segmentation_export: int = default_num_processes,
                           folder_with_segs_from_prev_stage: str = None,
                           num_parts: int = 1,
                           part_id: int = 0):
        """Main entry point: predict all cases found in input_folder.

        Writes predict_args.json / dataset.json / plans.json into the output
        folder for reproducibility, then streams preprocessed cases through
        the sliding-window predictor.
        """
        # Prepare the output folder and persist the call arguments.
        maybe_mkdir_p(output_folder)
        init_kwargs = {k: v for k, v in locals().items() if k != 'self'}
        recursive_fix_for_json_export(init_kwargs)
        save_json(init_kwargs, join(output_folder, 'predict_args.json'))
        save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
        save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
        # Cascade models require the previous stage's segmentations.
        if self.configuration_manager.previous_stage_name is not None:
            assert folder_with_segs_from_prev_stage is not None, \
                f'级联模型需要前序分割结果,请通过folder_with_segs_from_prev_stage指定'
        # Resolve input/output/prev-stage lists.
        input_list, output_list, prev_seg_list = self._manage_input_and_output_lists(
            input_folder, output_folder, folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
            save_probabilities)
        if len(input_list) == 0:
            print('无病例需要处理')
            return
        # Background preprocessing iterator.
        data_iterator = preprocessing_iterator_fromfiles(
            input_list, prev_seg_list, output_list,
            self.plans_manager, self.dataset_json, self.configuration_manager,
            num_processes_preprocessing, self.device.startswith('cuda'), self.verbose_preprocessing
        )
        # Run inference.
        self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)

    def predict_from_data_iterator(self,
                                   data_iterator,
                                   save_probabilities: bool = False,
                                   num_processes_segmentation_export: int = default_num_processes):
        """Predict each preprocessed case and export results in background workers."""
        with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
            worker_list = [i for i in export_pool._pool]
            export_tasks = []
            for preprocessed in data_iterator:
                data = preprocessed['data']
                if isinstance(data, str):
                    # The preprocessing workers spilled the array to a temp .npy
                    # file. Remember the path BEFORE rebinding `data`, otherwise
                    # os.remove would be handed the loaded ndarray (bugfix) and
                    # the temp file would leak.
                    tmp_file = data
                    data = np.load(tmp_file)
                    os.remove(tmp_file)
                else:
                    data = data.numpy()
                output_path = preprocessed['ofile']
                if output_path:
                    print(f'开始预测: {os.path.basename(output_path)}')
                else:
                    print(f'开始预测: 图像形状 {data.shape}')
                # Throttle if too many export tasks are queued.
                while check_workers_alive_and_busy(export_pool, worker_list, export_tasks, allowed_num_queued=2):
                    sleep(0.1)
                # Sliding-window prediction (multi-fold ensemble).
                logits = self.predict_logits_from_preprocessed_data(data)
                # Submit the export task to the background pool.
                properties = preprocessed['data_properties']
                if output_path:
                    export_tasks.append(export_pool.starmap_async(
                        export_prediction_from_logits,
                        ((logits, properties, self.configuration_manager, self.plans_manager,
                          self.dataset_json, output_path, save_probabilities),)
                    ))
                else:
                    export_tasks.append(export_pool.starmap_async(
                        convert_predicted_logits_to_segmentation_with_correct_shape,
                        ((logits, self.plans_manager, self.configuration_manager, self.label_manager,
                          properties, save_probabilities),)
                    ))
                if output_path:
                    print(f'预测完成: {os.path.basename(output_path)}(结果导出中)')
            # Block until all exports are finished.
            for task in export_tasks:
                task.get()
        print('所有病例处理完成')

    def predict_logits_from_preprocessed_data(self, data: np.ndarray) -> np.ndarray:
        """Predict logits for one preprocessed case, averaging over all folds.

        :param data: preprocessed image, shape (c, x, y, z)
        :returns: averaged logits, shape (num_segmentation_heads, x, y, z)
        """
        prediction = None
        for session in self.onnx_sessions:
            fold_logits = self.predict_sliding_window_return_logits(data, session)
            if prediction is None:
                prediction = fold_logits
            else:
                prediction += fold_logits
        # Average the fold ensemble.
        if len(self.onnx_sessions) > 1:
            prediction /= len(self.onnx_sessions)
        return prediction

    def predict_sliding_window_return_logits(self, input_image: np.ndarray,
                                             session: ort.InferenceSession) -> np.ndarray:
        """Sliding-window prediction with optional Gaussian fusion weights.

        :param input_image: 4D array (c, x, y, z)
        :param session: onnxruntime session of one fold
        :returns: logits with the padding removed, same spatial shape as input
        """
        assert input_image.ndim == 4, '输入必须是4D数组 (c, x, y, z)'
        if self.verbose:
            print(f'输入形状: {input_image.shape}, 滑动窗口大小: {self.configuration_manager.patch_size}')
        # Pad so every sliding window fits inside the image.
        data, revert_pad_slicer = pad_nd_image(
            input_image, self.configuration_manager.patch_size, 'constant', {'value': 0}, True, None
        )
        # Window positions over the padded spatial dimensions.
        slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
        # Accumulators for weighted logits and per-voxel weights (fp16 to save RAM).
        logits_shape = (self.label_manager.num_segmentation_heads, *data.shape[1:])
        predicted_logits = np.zeros(logits_shape, dtype=np.float16)
        prediction_counts = np.zeros(data.shape[1:], dtype=np.float16)
        # Gaussian importance map emphasizing window centers.
        gaussian = None
        if self.use_gaussian:
            gaussian_tensor = compute_gaussian(tuple(self.configuration_manager.patch_size),
                                               sigma_scale=1. / 8, value_scaling_factor=10, device='cpu')
            gaussian = gaussian_tensor.numpy().astype(np.float16)
        from tqdm import tqdm
        for sl in tqdm(slicers, desc='滑动窗口预测', disable=not self.allow_tqdm):
            window_data = data[sl][None]  # add batch dimension
            # Predict with optional mirroring TTA; drop the batch dimension.
            window_logits = self._internal_maybe_mirror_and_predict(window_data, session)[0]
            # Accumulate weighted logits and weights (sl[1:] drops the channel slice).
            if self.use_gaussian:
                window_logits *= gaussian
                prediction_counts[sl[1:]] += gaussian
            else:
                prediction_counts[sl[1:]] += 1
            predicted_logits[sl] += window_logits.astype(np.float16)
        # Normalize by the accumulated weights.
        predicted_logits /= prediction_counts
        # Undo the padding (keep the full channel axis).
        predicted_logits = predicted_logits[(slice(None), *revert_pad_slicer[1:])]
        return predicted_logits

    def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
        """Compute the slicers (window positions) covering the padded image.

        Handles both a 2D patch over a 3D image (slice-wise) and a 3D patch.
        """
        if len(self.configuration_manager.patch_size) < len(image_size):
            # 2D patch applied slice-by-slice along the first spatial axis.
            steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size,
                                                     self.tile_step_size)
            slicers = []
            for d in range(image_size[0]):
                for sx in steps[0]:
                    for sy in steps[1]:
                        slicers.append(tuple([slice(None), d,
                                              slice(sx, sx + self.configuration_manager.patch_size[0]),
                                              slice(sy, sy + self.configuration_manager.patch_size[1])]))
        else:
            # Full 3D windows.
            steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size,
                                                     self.tile_step_size)
            slicers = []
            for sx in steps[0]:
                for sy in steps[1]:
                    for sz in steps[2]:
                        slicers.append(tuple([slice(None),
                                              slice(sx, sx + self.configuration_manager.patch_size[0]),
                                              slice(sy, sy + self.configuration_manager.patch_size[1]),
                                              slice(sz, sz + self.configuration_manager.patch_size[2])]))
        return slicers

    def _internal_maybe_mirror_and_predict(self, x: np.ndarray, session: ort.InferenceSession) -> np.ndarray:
        """Run one window through the network, optionally with mirroring TTA.

        :param x: window with batch dimension, shape (1, c, *patch_size)
        :returns: averaged logits over the original and all mirrored passes
        """
        pred = session.run([self.output_name], {self.input_name: x.astype(np.float32)})[0]
        if self.use_mirroring and self.allowed_mirroring_axes:
            # Shift spatial axes by 2 to account for (batch, channel) dims.
            mirror_axes = [m + 2 for m in self.allowed_mirroring_axes]
            # All non-empty combinations of the mirror axes.
            mirror_combinations = [c for i in range(len(mirror_axes)) for c in
                                   itertools.combinations(mirror_axes, i + 1)]
            for axes in mirror_combinations:
                x_flipped = np.flip(x, axes)
                pred_flipped = session.run([self.output_name], {self.input_name: x_flipped.astype(np.float32)})[0]
                pred += np.flip(pred_flipped, axes)  # flip back before accumulating
            # Average over original + mirrored predictions.
            pred /= (len(mirror_combinations) + 1)
        return pred
def onnx_predict_entry_point():
    """CLI entry point: parse arguments, build the predictor and run inference."""
    import argparse
    parser = argparse.ArgumentParser(description='nnU-Net ONNX模型推理(文件夹输入)')
    parser.add_argument('-i', type=str, required=True, help='输入文件夹(含图像文件)')
    parser.add_argument('-o', type=str, required=True, help='输出文件夹(保存分割结果)')
    parser.add_argument('-m', type=str, required=True, help='ONNX模型文件夹(含fold_X子文件夹)')
    parser.add_argument('-f', nargs='+', type=str, default=(0, 1, 2, 3, 4), help='使用的fold(默认0-4)')
    parser.add_argument('-c', type=str, default='3d_fullres', help='配置名称(如3d_fullres)')
    parser.add_argument('-step_size', type=float, default=0.5, help='滑动窗口步长(0.5-1.0,越小越准)')
    parser.add_argument('--disable_tta', action='store_true', help='禁用镜像增强(加速)')
    parser.add_argument('--save_probabilities', action='store_true', help='保存概率图')
    parser.add_argument('--continue_prediction', action='store_true', help='继续中断的预测(不覆盖)')
    parser.add_argument('-npp', type=int, default=3, help='预处理进程数')
    parser.add_argument('-nps', type=int, default=3, help='导出进程数')
    parser.add_argument('-prev_stage_predictions', type=str, default=None, help='前序分割结果文件夹(级联模型用)')
    parser.add_argument('-device', type=str, default='cuda', help='设备(cuda/cpu/mps)')
    parser.add_argument('--disable_progress_bar', action='store_true', help='禁用进度条')
    args = parser.parse_args()
    # Normalize folds: keep the literal 'all', convert everything else to int.
    parsed_folds = []
    for fold in args.f:
        parsed_folds.append(fold if fold == 'all' else int(fold))
    args.f = parsed_folds
    # Make sure the output folder exists.
    maybe_mkdir_p(args.o)
    # Build the predictor from the CLI flags.
    predictor = nnUNetONNXPredictor(
        tile_step_size=args.step_size,
        use_gaussian=True,
        use_mirroring=not args.disable_tta,
        device=args.device,
        allow_tqdm=not args.disable_progress_bar
    )
    # Load the exported model(s) and configuration.
    predictor.initialize_from_onnx_model_folder(
        onnx_model_folder=args.m,
        use_folds=args.f,
        configuration_name=args.c
    )
    # Run inference over the input folder.
    predictor.predict_from_files(
        input_folder=args.i,
        output_folder=args.o,
        save_probabilities=args.save_probabilities,
        overwrite=not args.continue_prediction,
        num_processes_preprocessing=args.npp,
        num_processes_segmentation_export=args.nps,
        folder_with_segs_from_prev_stage=args.prev_stage_predictions
    )
if __name__ == '__main__':
    # Example: programmatic usage with hard-coded folder paths.
    onnx_predictor = nnUNetONNXPredictor(
        tile_step_size=0.5,
        use_mirroring=True,
        device='cuda',
        allow_tqdm=True
    )
    # ONNX model folder containing fold_X subfolders with model.onnx files.
    onnx_predictor.initialize_from_onnx_model_folder(
        onnx_model_folder=r'D:\dd\nnUNet-master\nnUNetFrame\nnUNet_results\Dataset997_data\nnUNetTrainer__nnUNetPlans__3d_fullres',
        use_folds=(0,),
        configuration_name='3d_fullres'
    )
    # Input image folder -> output segmentation folder.
    onnx_predictor.predict_from_files(
        input_folder='D:/dd/nnUNet-master/nnUNetFrame/test',
        output_folder='D:/dd/nnUNet-master/nnUNetFrame/testresult_onnx',
        save_probabilities=False
    )
    # Alternative: run the command-line entry point instead of the example above.
    # onnx_predict_entry_point()
# NOTE: leftover note (was stray non-code text that broke parsing, kept as a
# comment): this is inference code for an nnUNetv2-trained model exported to
# ONNX; the request was to simplify it while keeping identical inference
# results — the multiprocessing/multithreading features are not required.