LYT-Net——轻量级网络低光照条件图像修复模型推理部署(C++/Python)

1.环境安装

conda create -n LYT_Torch python=3.9 -y
conda activate LYT_Torch
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

pip install matplotlib scikit-learn scikit-image opencv-python yacs joblib natsort h5py tqdm tensorboard

pip install einops gdown addict future lmdb numpy pyyaml requests scipy yapf lpips thop timm

2. Onnx模型推理

2.1 Onnxruntime

ONNX Runtime(Open Neural Network Exchange Runtime)是一个用于推理(inference)的高性能、跨平台的开源引擎,它由微软公司发起并维护。ONNX Runtime 支持多种机器学习模型,包括但不限于 PyTorch、TensorFlow、SciKit-Learn 等,允许开发者在不同框架之间无缝迁移模型,并在生产环境中高效运行。

  • 跨平台:支持 Windows、Linux 和 macOS 等多种操作系统。
  • 高性能:优化了模型的推理速度,特别是在使用 GPU 加速时。
  • 多框架支持:原生支持 ONNX 格式的模型,同时也支持从其他框架转换来的模型。
  • 实时性:适用于需要实时或近实时推理的应用场景。
  • 易于集成:提供了 C++ 和 Python API,方便集成到不同语言的应用程序中。
  • 自动混合精度:支持自动混合精度(AMP),可以在不损失模型精度的情况下提高推理速度。
  • 量化支持:支持模型量化,可以进一步减小模型大小并提高推理速度。
  • 微服务:支持作为微服务部署,易于在云环境或边缘设备上使用。

使用 ONNX Runtime 进行模型推理的基本步骤如下:

  • 转换模型:如果模型不是 ONNX 格式,需要使用相应的转换工具将其转换为 ONNX 格式。
  • 加载模型:使用 ONNX Runtime 的 API 加载转换后的模型。
  • 准备输入:准备模型推理所需的输入数据,并确保其符合模型的输入要求。
  • 运行推理:调用 ONNX Runtime 的推理 API 运行模型并获取输出。
  • 处理输出:根据应用需求处理模型的推理输出。

2.2 C++推理

#include "LYTNet.h"

LYTNet::LYTNet(std::string model_path)
{
	OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);   ///cuda

	sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
	std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); windows
	ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions); windows*/
	//ort_session = new Session(env, model_path.c_str(), sessionOptions); linux

	size_t numInputNodes = ort_session->GetInputCount();
	size_t numOutputNodes = ort_session->GetOutputCount();
	Ort::AllocatorWithDefaultOptions allocator;
	for (int i = 0; i < numInputNodes; i++)
	{
		input_names.push_back(ort_session->GetInputName(i, allocator));
		Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
		auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
		auto input_dims = input_tensor_info.GetShape();
		input_node_dims.push_back(input_dims);
	}
	for (int i = 0; i < numOutputNodes; i++)
	{
		output_names.push_back(ort_session->GetOutputName(i, allocator));
		Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
		auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
		auto output_dims = output_tensor_info.GetShape();
		output_node_dims.push_back(output_dims);
	}
	this->inpHeight = input_node_dims[0][1];
	this->inpWidth = input_node_dims[0][2];
	this->outHeight = output_node_dims[0][1];
	this->outWidth = output_node_dims[0][2];
}


/***************** Mat转vector **********************/
template<typename _Tp>
std::vector<_Tp> convertMat2Vector(const cv::Mat& mat)
{
	return (std::vector<_Tp>)(mat.reshape(1, 1));//通道数不变,按行转为一行
}

/****************** vector转Mat *********************/
template<typename _Tp>
cv::Mat convertVector2Mat(std::vector<_Tp> v, int channels, int rows)
{
	cv::Mat mat = cv::Mat(v).clone();//将vector变成单列的mat,这里需要clone(),因为这里的赋值操作是浅拷贝
	cv::Mat dest = mat.reshape(channels, rows);
	return dest;
}


cv::Mat LYTNet::detect(cv::Mat srcimg)
{
	cv::Mat dstimg;
	resize(srcimg, dstimg, cv::Size(this->inpWidth, this->inpHeight));
	dstimg.convertTo(dstimg, CV_32FC3, 1 / 127.5, -1.0);
	this->input_image_ = (std::vector<float>)(dstimg.reshape(1, 1));

	// const size_t area = this->inpWidth * this->inpHeight * 3;
	// this->input_image_.resize(area);
	// memcpy(this->input_image_.data(), (float*)dstimg.data, area*sizeof(float));

	std::array<int64_t, 4> input_shape_{ 1, this->inpHeight, this->inpWidth, 3 };

	auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info, input_image_.data(),
		input_image_.size(), input_shape_.data(), input_shape_.size());

	std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],
		&input_tensor_, 1, output_names.data(), output_names.size());  

	float* pred = ort_outputs[0].GetTensorMutableData<float>();
	cv::Mat output_image(outHeight, outWidth, CV_32FC3, pred);
	output_image = (output_image + 1.0) * 127.5;
	output_image.convertTo(output_image, CV_8UC3);
	resize(output_image, output_image, cv::Size(srcimg.cols, srcimg.rows));
	return output_image;
}

2.3 Python推理部署

import argparse
import cv2
import numpy as np
import onnxruntime

class LYTNet:
    def __init__(self, modelpath):
        # Initialize model
        # self.net = cv2.dnn.readNet(modelpath) ####opencv-dnn读取失败
        so = onnxruntime.SessionOptions()
        so.log_severity_level = 3
        self.net = onnxruntime.InferenceSession(modelpath, so)
        self.input_height = self.net.get_inputs()[0].shape[1]   ####(1,h,w,3)
        self.input_width = self.net.get_inputs()[0].shape[2]
        self.input_name = self.net.get_inputs()[0].name

    def detect(self, srcimg): 
        input_image = cv2.resize(srcimg, (self.input_width, self.input_height))
        input_image = input_image.astype(np.float32) / 127.5 - 1.0
        blob = np.expand_dims(input_image, axis=0).astype(np.float32)

        result = self.net.run(None, {self.input_name: blob})

        # Post process:squeeze, RGB->BGR, Transpose, uint8 cast
        output_image = np.squeeze(result[0])
        output_image = (output_image + 1.0 ) * 127.5
        output_image = output_image.astype(np.uint8)
        output_image = cv2.resize(output_image, (srcimg.shape[1], srcimg.shape[0]))
        return output_image


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str,
                        default='testimgs/4_1.JPG', help="image path")
    parser.add_argument('--modelpath', type=str,
                        default='weights/lyt_net_lolv2_real_320x240.onnx', help="model path")
    args = parser.parse_args()

    mynet = LYTNet(args.modelpath)
    srcimg = cv2.imread(args.imgpath)

    dstimg = mynet.detect(srcimg)

    if srcimg.shape[0] >= srcimg.shape[1]:
        boundimg = np.zeros((srcimg.shape[0], 10, 3), dtype=srcimg.dtype)+255
        combined_img = np.hstack([srcimg, boundimg, dstimg])
    else:
        boundimg = np.zeros((10, srcimg.shape[1], 3), dtype=srcimg.dtype)+255
        combined_img = np.vstack([srcimg, boundimg, dstimg])
    
    winName = 'Deep Learning use OpenCV-dnn'
    cv2.namedWindow(winName, 0)
    cv2.imshow(winName, combined_img)  ###原图和结果图也可以分开窗口显示
    cv2.waitKey(0)
    cv2.destroyAllWindows()

2.4 实现效果

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

代码与模型下载地址:https://download.youkuaiyun.com/download/matt45m/89600913

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

知来者逆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值