在 ONNX Runtime 中,io_binding()
是一种高性能数据绑定机制,允许直接将输入/输出张量绑定到特定的内存(如 GPU 显存),避免不必要的内存拷贝,从而提升推理速度。
a. 获取 io_binding
对象的方法
步骤 1:创建 ONNX Runtime 会话(Session)
首先需要加载模型并创建推理会话,指定执行提供器(如 CUDA EP):
import onnxruntime as ort # 创建会话,启用 CUDA 执行提供器 session = ort.InferenceSession( "model.onnx", providers=["CUDAExecutionProvider"] # 或 ["CPUExecutionProvider"] )
步骤 2:从会话中获取 IoBinding
对象
通过 session.io_binding()
直接创建绑定对象:
io_binding = session.io_binding() # 关键步骤:获取 io_binding 对象
b.基本使用步骤
(1) 创建 IoBinding
对象
import onnxruntime as ort session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"]) io_binding = session.io_binding()
(2) 绑定输入张量
-
示例:绑定 GPU 显存中的输入
import numpy as np import torch # 假设输入是 PyTorch CUDA 张量(或 NumPy 数组) input_data = torch.randn(1, 3, 224, 224).cuda() # 显存中的张量 # 将张量绑定到 IO Binding io_binding.bind_input( name="input_name", # 模型的输入节点名称 device_type="cuda", # 设备类型('cuda' 或 'cpu') device_id=0, # GPU 设备 ID element_type=np.float32, shape=input_data.shape, buffer_ptr=input_data.data_ptr() # 显存指针 )
(3) 绑定输出张量
-
示例:预分配 GPU 显存作为输出
# 预分配输出显存(需知道输出形状) output_shape = (1, 1000) # 假设模型输出形状 output_buffer = torch.empty(output_shape, dtype=torch.float32).cuda() io_binding.bind_output( name="output_name", # 模型的输出节点名称 device_type="cuda", device_id=0, element_type=np.float32, shape=output_shape, buffer_ptr=output_buffer.data_ptr() )
(4) 执行推理
session.run_with_iobinding(io_binding)
(5) 获取输出结果
# 输出数据已在预分配的 output_buffer 中(GPU 显存) output_result = output_buffer.cpu().numpy() # 如需 CPU 数据,拷贝回主机
c.接口方法
在 ONNX Runtime 中,bind_input
和 bind_ortvalue_input
都是 IoBinding
接口提供的方法,用于绑定输入数据,但它们的设计目标、灵活性和底层实现有显著区别。
bind_input
功能
-
直接将 原始内存指针(如 NumPy 数组、PyTorch 张量的指针)绑定到模型的输入。
-
需要手动指定内存的设备类型(CPU/GPU)、形状、数据类型等元信息。
适用场景
-
已有现成的内存(如 CUDA 显存中的 PyTorch 张量或 CPU 的 NumPy 数组)。
-
需要精细控制内存布局或避免额外拷贝时。
示例
import torch # 假设 input_tensor 是 GPU 张量 input_tensor = torch.randn(1, 3, 224, 224).cuda() io_binding.bind_input( name="input_name", device_type="cuda", device_id=0, element_type=np.float32, shape=input_tensor.shape, buffer_ptr=input_tensor.data_ptr() # 直接传递显存指针 )
优缺点
-
优点:
-
零拷贝,性能高。
-
支持任意来源的内存(只要提供指针)。
-
-
缺点:
-
需要手动管理所有元信息(
shape
、dtype
、device
)。 -
对动态形状支持较弱(需手动更新形状)。
-
bind_ortvalue_input
功能
-
绑定一个
OrtValue
对象 作为输入。 -
OrtValue
是 ONNX Runtime 提供的高级数据结构,内部封装了内存指针、设备信息、形状和数据类型等元信息。
适用场景
-
使用 ONNX Runtime 内置的
OrtValue
对象管理数据(例如从其他 ORT 接口返回的数据)。 -
需要更简洁的 API 或动态形状支持时。
示例
import numpy as np import onnxruntime as ort # 创建 OrtValue(从 NumPy 数组) input_np = np.random.randn(1, 3, 224, 224).astype(np.float32) ort_value = ort.OrtValue.ortvalue_from_numpy(input_np, "cuda", 0) # 自动分配 GPU 显存 # 绑定 OrtValue io_binding.bind_ortvalue_input("input_name", ort_value)
优缺点
-
优点:
-
更简洁,自动管理元信息(无需手动指定
shape
、dtype
)。 -
支持动态形状(
OrtValue
内部自动更新)。 -
与 ONNX Runtime 其他接口(如
run()
)无缝兼容。
-
-
缺点:
-
需要额外创建
OrtValue
对象(轻微开销)。 -
对非 ORT 管理的内存(如 PyTorch 张量)需要转换。
-
核心区别对比
特性 | bind_input | bind_ortvalue_input |
---|---|---|
输入类型 | 原始内存指针(data_ptr ) | OrtValue 对象 |
设备类型指定 | 需手动指定(device_type="cuda" ) | 由 OrtValue 内部管理 |
形状/数据类型 | 需手动指定 | 自动从 OrtValue 获取 |
动态形状支持 | 需手动更新 shape | 自动更新(OrtValue 内部处理) |
内存来源 | 支持任意内存(PyTorch、NumPy 等) | 需通过 ORT 接口创建(如 ortvalue_from_numpy ) |
性能 | 零拷贝,最高效 | 轻微开销(OrtValue 封装层) |
代码简洁性 | 较冗长 | 更简洁 |
如何选择?
-
优先用
bind_ortvalue_input
如果:-
数据已经是
OrtValue
或来自 ONNX Runtime 其他接口。 -
需要代码简洁性或动态形状支持。
-
不介意轻微的性能开销(如
OrtValue
创建成本)。
-
-
优先用
bind_input
如果:-
需要直接绑定 PyTorch/CUDA 等外部内存(避免额外拷贝)。
-
对性能有极致要求(如高频推理场景)。
-
能手动管理内存元信息(
shape
、dtype
)。
-
完整示例对比
(1) 使用 bind_input
(原始指针)
# 准备 PyTorch CUDA 张量 input_tensor = torch.randn(1, 3, 224, 224).cuda() # 绑定输入(显式指定所有参数) io_binding.bind_input( name="input", device_type="cuda", device_id=0, element_type=np.float32, shape=input_tensor.shape, buffer_ptr=input_tensor.data_ptr() )
(2) 使用 bind_ortvalue_input
(OrtValue)
# 将 NumPy 数组转换为 OrtValue(自动分配 GPU 显存) input_np = np.random.randn(1, 3, 224, 224).astype(np.float32) ort_value = ort.OrtValue.ortvalue_from_numpy(input_np, "cuda", 0) # 绑定输入(无需手动指定元信息) io_binding.bind_ortvalue_input("input", ort_value)
注意事项
-
内存生命周期:
-
使用
bind_input
时,需确保绑定的内存指针在推理完成前有效(如 PyTorch 张量未被释放)。 -
OrtValue
会自行管理内存,但需注意作用域。
-
-
设备一致性:
-
确保绑定的设备类型(CPU/GPU)与模型执行的 EP(如 CUDA EP)匹配。
-
-
错误排查:
-
如果绑定失败,检查
shape
和dtype
是否与模型输入节点一致。
-
d. onnx高级内存管理机制示例
output_names = [output.name for output in context.get_outputs()]
-
context
:
通常是 ONNX Runtime 的InferenceSession
对象(模型会话),通过ort.InferenceSession("model.onnx")
创建。 -
context.get_outputs()
:
ONNX Runtime 的方法,返回模型的所有输出节点信息(列表形式,每个元素是一个onnxruntime.NodeArg
对象)。 -
output.name
:
每个NodeArg
对象的name
属性,表示该输出节点的名称(字符串)。 -
列表推导式:
遍历get_outputs()
返回的所有输出节点,提取它们的名称,最终生成一个名称列表。
for o_n in output_names:
io_binding.bind_output( o_n, "pinned")
在bind_output参数中没有指定buffer_ptr,是因为它使用了 ONNX Runtime 的 高级内存管理机制,
关键原因:"pinned"
参数的作用
当调用 bind_output(out, "pinned")
时:
-
"pinned"
是一个 特殊标识符,告诉 ONNX Runtime:-
自动分配固定内存(Pinned Memory):
ONNX Runtime 内部会为输出张量分配一块 页锁定内存(Page-Locked Memory),而不是要求用户手动传入指针。 -
内存由 ORT 管理:
这块内存的生命周期由 ONNX Runtime 控制,用户无需手动释放。
-
与显式绑定指针(buffer_ptr
)的区别
绑定方式 | 用户是否提供指针 | 内存管理方 | 适用场景 |
---|---|---|---|
bind_output(name, "pinned") | ❌ 无需 | ONNX Runtime 自动管理 | 快速测试、无需精细控制内存时 |
bind_output(name, device_type, buffer_ptr) | ✅ 需手动传入指针 | 用户自行管理 | 高性能场景、需 |