超实用!LaMa推理优化:ONNX导出与TensorRT加速实践指南
【免费下载链接】lama 项目地址: https://gitcode.com/gh_mirrors/lam/lama
你是否在使用LaMa进行图像修复时遇到推理速度慢的问题?特别是处理高分辨率图像时,等待时间过长影响工作效率?本文将详细介绍如何通过ONNX导出与TensorRT加速技术,显著提升LaMa模型的推理性能,让你在保持修复质量的同时,享受飞一般的速度体验。读完本文,你将掌握模型导出、优化和部署的全流程,轻松应对大规模图像修复任务。
LaMa模型简介
LaMa(Resolution-robust Large Mask Inpainting with Fourier Convolutions)是一款基于傅里叶卷积的高分辨率图像修复模型。它在训练时使用256x256的图像,但却能惊人地泛化到高达2k的分辨率,在周期性结构等具有挑战性的场景中也能实现出色的修复效果。
LaMa的核心优势在于其采用的傅里叶卷积技术,这使得模型能够有效捕捉图像中的全局结构信息。模型的整体架构基于ResNet,包含下采样、瓶颈层和上采样三个主要部分。其中,瓶颈层使用了多个残差块,部分配置中还引入了FFC(Fourier Filter Convolution)块来增强模型的特征提取能力。

LaMa的原始实现使用PyTorch框架,其推理代码主要位于saicinpainting/training/trainers/default.py文件中。该文件定义了DefaultInpaintingTrainingModule类,负责模型的前向传播和损失计算。在推理过程中,模型接收图像和掩码作为输入,输出修复后的图像。
环境准备与模型获取
在开始优化之前,我们需要先准备好必要的环境并获取LaMa预训练模型。以下是详细步骤:
1. 克隆项目仓库
首先,克隆LaMa项目的Git仓库:
git clone https://gitcode.com/gh_mirrors/lam/lama
cd lama
2. 创建虚拟环境
推荐使用conda创建独立的虚拟环境,以避免依赖冲突:
conda env create -f conda_env.yml
conda activate lama
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -y
pip install pytorch-lightning==1.2.9
3. 下载预训练模型
LaMa提供了多个预训练模型,我们使用性能最佳的big-lama模型:
curl -LJO https://huggingface.co/smartywu/big-lama/resolve/main/big-lama.zip
unzip big-lama.zip
下载完成后,模型文件将保存在当前目录下的big-lama文件夹中。
ONNX模型导出
ONNX(Open Neural Network Exchange)是一种开放的模型格式,支持多种深度学习框架。将PyTorch模型导出为ONNX格式,不仅可以实现跨框架部署,还能为后续的TensorRT优化做好准备。
1. 导出准备
LaMa原始代码中并未直接提供ONNX导出功能,因此我们需要编写一个简单的导出脚本。首先,创建一个名为export_onnx.py的文件,然后导入必要的库和模块:
import torch
import yaml
import os
from saicinpainting.training.modules.pix2pixhd import GlobalGenerator
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2. 加载模型配置和权重
接下来,我们需要加载模型的配置文件和预训练权重。LaMa使用yaml格式的配置文件,我们可以从configs/training/big-lama.yaml中获取模型结构参数:
# 加载配置文件
config_path = "configs/training/big-lama.yaml"
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# 创建模型
generator_config = config['generator']
model = GlobalGenerator(
input_nc=generator_config['input_nc'],
output_nc=generator_config['output_nc'],
ngf=generator_config['ngf'],
n_downsampling=generator_config['n_downsampling'],
n_blocks=generator_config['n_blocks'],
norm_layer=torch.nn.BatchNorm2d,
padding_type=generator_config['padding_type'],
ffc_positions=generator_config.get('ffc_positions', None),
ffc_kwargs=generator_config.get('ffc_kwargs', {})
).to(device)
# 加载预训练权重
checkpoint_path = "big-lama/last.ckpt"
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()
3. 导出ONNX模型
现在我们可以将模型导出为ONNX格式了。需要注意的是,LaMa模型的输入是图像和掩码的拼接,因此我们需要创建一个合适的输入张量:
# 创建输入张量
input_nc = generator_config['input_nc'] # 通常为4 (3通道图像 + 1通道掩码)
height, width = 512, 512 # 可以根据需要调整
dummy_input = torch.randn(1, input_nc, height, width, device=device)
# 导出ONNX模型
onnx_path = "big-lama.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_path,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {2: 'height', 3: 'width'},
'output': {2: 'height', 3: 'width'}
}
)
print(f"ONNX模型已导出至: {onnx_path}")
4. 导出注意事项
在导出过程中,可能会遇到一些问题,这里提供几个解决方案:
-
动态输入尺寸:通过设置
dynamic_axes参数,允许模型接受不同尺寸的输入图像。 -
不支持的操作:如果遇到PyTorch操作不支持ONNX导出的情况,可以尝试修改模型代码或降低ONNX的opset版本。
-
模型简化:导出后可以使用ONNX Simplifier工具简化模型,去除冗余操作:
pip install onnx-simplifier
python -m onnxsim big-lama.onnx big-lama-sim.onnx
TensorRT加速
TensorRT是NVIDIA开发的高性能深度学习推理SDK,能够显著提高模型在GPU上的推理速度。下面我们将介绍如何使用TensorRT优化导出的ONNX模型。
1. TensorRT环境搭建
首先,确保已安装TensorRT。可以通过以下命令安装TensorRT Python包:
pip install tensorrt
注意:TensorRT版本需要与CUDA版本匹配,具体安装方法请参考NVIDIA官方文档。
2. ONNX模型转换为TensorRT引擎
使用TensorRT的Python API,我们可以将ONNX模型转换为优化的TensorRT引擎:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
# 创建Logger
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# 构建TensorRT引擎
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
# 解析ONNX模型
onnx_path = "big-lama-sim.onnx"
with open(onnx_path, 'rb') as model_file:
parser.parse(model_file.read())
# 设置构建配置
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30 # 1GB
config.set_flag(trt.BuilderFlag.FP16) # 启用FP16精度
# 构建并保存引擎
serialized_engine = builder.build_serialized_network(network, config)
engine_path = "big-lama.engine"
with open(engine_path, "wb") as f:
f.write(serialized_engine)
print(f"TensorRT引擎已保存至: {engine_path}")
3. TensorRT推理代码实现
创建一个使用TensorRT引擎进行推理的函数:
import numpy as np
import pycuda.driver as cuda
class TensorRTInfer:
def __init__(self, engine_path):
self.engine_path = engine_path
self.trt_logger = trt.Logger(trt.Logger.WARNING)
self.runtime = trt.Runtime(self.trt_logger)
self.engine = self.load_engine()
self.context = self.engine.create_execution_context()
self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
def load_engine(self):
with open(self.engine_path, 'rb') as f:
engine_data = f.read()
return self.runtime.deserialize_cuda_engine(engine_data)
def allocate_buffers(self):
inputs = []
outputs = []
bindings = []
stream = cuda.Stream()
for binding in self.engine:
size = trt.volume(self.engine.get_binding_shape(binding)) * self.engine.max_batch_size
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
# 分配主机和设备缓冲区
host_mem = cuda.pagelocked_empty(size, dtype)
device_mem = cuda.mem_alloc(host_mem.nbytes)
# 将设备缓冲区地址添加到绑定列表
bindings.append(int(device_mem))
# Append to the appropriate list.
if self.engine.binding_is_input(binding):
inputs.append({'host': host_mem, 'device': device_mem})
else:
outputs.append({'host': host_mem, 'device': device_mem})
return inputs, outputs, bindings, stream
def infer(self, input_image):
# 将输入数据复制到主机缓冲区
np.copyto(self.inputs[0]['host'], input_image.ravel())
# 将数据从主机复制到设备
cuda.memcpy_htod_async(self.inputs[0]['device'], self.inputs[0]['host'], self.stream)
# 执行推理
self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle)
# 将结果从设备复制到主机
cuda.memcpy_dtoh_async(self.outputs[0]['host'], self.outputs[0]['device'], self.stream)
# 同步流
self.stream.synchronize()
# 返回结果
return self.outputs[0]['host'].reshape(input_image.shape[0], 3, input_image.shape[2], input_image.shape[3])
4. 推理性能对比
为了验证TensorRT加速效果,我们可以对比PyTorch、ONNX Runtime和TensorRT三种推理方式的性能:
import time
import torch
import numpy as np
# 创建测试输入
input_size = (1, 4, 512, 512) # (batch, channels, height, width)
input_data = np.random.rand(*input_size).astype(np.float32)
# PyTorch推理
torch_model = ... # 加载PyTorch模型
torch_input = torch.from_numpy(input_data).to(device)
torch_model.eval()
with torch.no_grad():
start_time = time.time()
torch_output = torch_model(torch_input)
torch_time = time.time() - start_time
print(f"PyTorch推理时间: {torch_time:.4f}秒")
# ONNX Runtime推理
import onnxruntime as ort
ort_session = ort.InferenceSession("big-lama-sim.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
start_time = time.time()
ort_outputs = ort_session.run(None, ort_inputs)
ort_time = time.time() - start_time
print(f"ONNX Runtime推理时间: {ort_time:.4f}秒")
# TensorRT推理
trt_infer = TensorRTInfer("big-lama.engine")
start_time = time.time()
trt_output = trt_infer.infer(input_data)
trt_time = time.time() - start_time
print(f"TensorRT推理时间: {trt_time:.4f}秒")
# 计算加速比
print(f"TensorRT相对PyTorch加速比: {torch_time / trt_time:.2f}x")
print(f"TensorRT相对ONNX Runtime加速比: {ort_time / trt_time:.2f}x")
通常情况下,TensorRT能够比原生PyTorch实现2-5倍的加速,具体加速效果取决于模型结构和硬件配置。
实际应用与优化建议
1. 批处理推理
对于大规模图像修复任务,可以使用批处理推理进一步提高效率。修改TensorRT引擎构建代码,设置最大批处理大小:
builder.max_batch_size = 8 # 设置最大批处理大小
然后在推理时传入批量图像数据。
2. 精度优化
TensorRT支持多种精度模式,包括FP32、FP16和INT8。在不影响修复质量的前提下,可以使用FP16或INT8精度进一步提高推理速度:
# FP16精度
config.set_flag(trt.BuilderFlag.FP16)
# INT8精度 (需要校准数据集)
config.set_flag(trt.BuilderFlag.INT8)
config.int8_calibrator = Int8Calibrator(calibration_data)
3. 多流推理
对于实时应用场景,可以使用多流推理充分利用GPU资源:
# 创建多个执行上下文
contexts = [engine.create_execution_context() for _ in range(num_streams)]
# 为每个流分配独立的缓冲区和流对象
4. 模型量化
除了使用TensorRT的INT8模式,还可以在PyTorch中对模型进行量化,然后导出为ONNX格式:
# 动态量化
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Conv2d}, dtype=torch.qint8
)
注意:量化可能会导致模型精度损失,需要在速度和精度之间权衡。
总结与展望
本文详细介绍了如何将LaMa模型导出为ONNX格式并使用TensorRT进行加速。通过这些优化步骤,我们可以显著提高LaMa模型的推理速度,使其能够更好地满足实际应用需求。
未来,我们还可以探索更多优化方向,如模型剪枝、知识蒸馏等,进一步提升LaMa的推理性能。同时,随着硬件技术的发展,新的加速技术和优化方法也将不断涌现,为图像修复任务带来更高的效率和更好的体验。
希望本文能够帮助你顺利实现LaMa模型的推理优化,如果你有任何问题或优化建议,欢迎在评论区留言讨论。
最后,附上本文涉及的主要代码文件和工具:
- 模型导出脚本:export_onnx.py
- TensorRT推理代码:trt_infer.py
- 性能测试脚本:benchmark.py
祝你在图像修复的道路上越走越远,创造出更多精彩的应用!

通过本文介绍的ONNX导出与TensorRT加速技术,你可以将LaMa模型的推理速度提升数倍,轻松应对各种图像修复任务。无论是大规模批量处理还是实时应用场景,这些优化技术都能为你带来显著的效率提升。现在就动手尝试,体验LaMa模型的极速推理吧!
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



