nnunetv2 series: running inference with onnxruntime using exported ONNX model weights

First of all, thanks to the author of https://blog.youkuaiyun.com/chen_niansan/article/details/142328247 for sharing the original example. The script below builds on that post, with some modifications plus an extra step that converts the prediction into a mask image that is easy to inspect by eye.
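The script assumes the trained nnunetv2 checkpoint has already been converted to an .onnx file whose input and output tensors are named "input" and "output". For reference, here is a minimal export sketch. It assumes a 2D, 3-channel model with a 512x512 patch size, and that the nnunetv2 predictor exposes nnUNetPredictor.initialize_from_trained_model_folder, .network and .list_of_parameters; the model folder path and checkpoint name are placeholders, so adjust everything to your own training run and nnunetv2 version.

import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# placeholder values, replace with your own results folder and patch_size
model_dir = "/xxx/nnUNetTrainer__nnUNetPlans__2d"
patch_h, patch_w = 512, 512

predictor = nnUNetPredictor(device=torch.device("cpu"))
predictor.initialize_from_trained_model_folder(
    model_dir, use_folds=(0,), checkpoint_name="checkpoint_best.pth"
)
network = predictor.network
# load the weights of fold 0 into the freshly built network
network.load_state_dict(predictor.list_of_parameters[0])
network.eval()

# dummy input shaped (N, C, H, W); H/W must match patch_size in plans.json
dummy = torch.randn(1, 3, patch_h, patch_w)
torch.onnx.export(
    network, dummy, "checkpoint_best_fold_0.onnx",
    input_names=["input"], output_names=["output"],
    opset_version=17,
)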

import os
import numpy as np
from cv2 import (
    imread,
    imwrite,
    cvtColor,
    COLOR_BGR2RGB,
    resize,
    INTER_LINEAR,
)
from time import time
from onnxruntime import (
    InferenceSession,
    # get_device,
)


def preprocess_image(image, input_size, in_ch=3):
    # input_size is (w, h), the order expected by cv2.resize
    image_resized = resize(
        image, input_size,
        interpolation=INTER_LINEAR
    )
    # scale pixel values to [0, 1]
    image_normalized = image_resized / 255.0
    # grayscale input: replicate the single channel to in_ch channels
    if len(image_normalized.shape) == 2:
        image_normalized = np.stack(
            [image_normalized] * in_ch,
            axis=-1
        )
    # HWC -> CHW, then add a batch dimension -> (1, C, H, W)
    image_normalized = np.transpose(image_normalized, (2, 0, 1))
    image_normalized = np.expand_dims(image_normalized, axis=0)

    return image_normalized.astype(np.float32)


if __name__ == "__main__":
    tic = time()
    # set this environment variable to pin inference to GPU 0
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    image_path = "./img_2_0000.png"
    # write the predicted mask to a separate file instead of overwriting the input
    image_pred_path = "./img_2_0000_pred.png"
    seg_model_path = "/xxx/checkpoint_best_fold_0.onnx"
    if not os.path.exists(seg_model_path):
        raise FileNotFoundError(
            f"Model file {seg_model_path} does not exist."
        )

    # device = get_device()
    device = "CPU"
    providers = (
        ["CUDAExecutionProvider"]
        if device != "CPU"
        else ["CPUExecutionProvider"]
    )
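
    # To check which execution providers this onnxruntime build actually offers
    # (e.g. whether CUDAExecutionProvider is available), one option is:
    #   from onnxruntime import get_available_providers
    #   print(get_available_providers())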

    seg_model = InferenceSession(
        seg_model_path, 
        providers=providers
    )
    image = imread(image_path)
    image = cvtColor(image, COLOR_BGR2RGB)
  
    # the network expects the patch_size = (h, w) recorded in plans.json;
    # the values below are placeholders, replace them with your model's patch size
    h, w = 512, 512
    image_input = preprocess_image(
        image,
        (w, h),  # cv2.resize takes (width, height)
        in_ch=3
    )
    # query the model's actual input/output names rather than hard-coding them;
    # they are "input"/"output" if the model was exported with those names
    input_name = seg_model.get_inputs()[0].name
    # print(f"==>> input_name: {input_name}")
    output_name = seg_model.get_outputs()[0].name
    # print(f"==>> output_name: {output_name}")
    # run() returns a list with one array per requested output
    mask_pred_info = seg_model.run(
        [output_name],
        {input_name: image_input}
    )
    # print(f"==>> mask_pred_info: {mask_pred_info}")
    mask_pred_batch = mask_pred_info[0]
    # drop the batch dimension: (1, num_classes, h, w) -> (num_classes, h, w)
    mask_pred = np.squeeze(mask_pred_batch, axis=0)
    # per-pixel argmax over the class dimension -> (h, w) label map
    mask_pred = np.argmax(mask_pred, axis=0)
    # pred_pixel_value_list = np.unique(mask_pred)

    # map the predicted class indices back to the custom label values used in the
    # original masks; replacing in ascending order is safe here because the new
    # values (128/196/255) never collide with the remaining class indices (2, 3)
    predict_recover_value_dict = {
        0: 0,
        1: 128,
        2: 196,
        3: 255
    }
    for predict_value, recover_value in predict_recover_value_dict.items():
        mask_pred[mask_pred == predict_value] = recover_value

    # cast to uint8 so cv2.imwrite stores a standard 8-bit grayscale image
    imwrite(image_pred_path, mask_pred.astype(np.uint8))
    print("time cost is: ", time() - tic)