nnunetv2系列:单个2D图像推理
对于单个图像,使用官方默认的推理方式会比较慢,而且在测试阶段推荐先用cpu而不是gpu进行推理,对于大批量的图像推荐使用官方默认的推理方式,且使用gpu。下面的代码同样根据官方给的示例进行调整得到推理单个2D图像的代码。相比默认的推理方式,这里支持图像名不以_0000.png结尾。为了方便人查看预测的效果,增加了恢复自定义的标签值模块。
代码示例
import numpy as np
from torch import device
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO
from cv2 import imwrite
from time import time
if __name__ == "__main__":
tic = time()
# instantiate the nnUNetPredictor
predictor = nnUNetPredictor(
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_device=False,
# device=device("cuda", 0),
device=device("cpu"),
verbose=False,
verbose_preprocessing=False,
allow_tqdm=True,
)
predictor.initialize_from_trained_model_folder(
# 直接使用绝对路径,替换join方法
model_training_output_dir="/xxx/nnUNet/nnUNet_results/xxx/nnUNetTrainer__nnUNetPlans__2d",
use_folds=(0,),
checkpoint_name="checkpoint_best.pth",
)
# 支持不以_0000结尾的文件名
image_path = "./test.png"
img, props = NaturalImage2DIO().read_images([image_path])
img = np.array(img, dtype=np.float32)
# print(img.shape)
ret = predictor.predict_single_npy_array(
input_image=img,
image_properties=props,
segmentation_previous_stage=None,
output_file_truncated=None,
save_or_return_probabilities=False
)
image_pred_path = "test_pred.png"
ret = ret.astype(np.uint8)
# print(f"==>> ret shape: {ret.shape}")
# print(f"==>> ret type: {type(ret)}")
ret = ret.transpose((1, 2, 0))
# 恢复自定义的标签值
predict_recover_value_dict = {
1: 128,
2: 196,
3: 255,
}
for predict_value, recover_value in predict_recover_value_dict.items():
ret[ret == predict_value] = recover_value
imwrite(image_pred_path, ret)
print(f"==>> time cost: {time() - tic}")