修改onnx模型中间节点命名(包含输入、输出重命名)

该文章已生成可运行项目,

来源:Paddle2ONNX

Paddle2ONNX/tools/onnx/README.md at develop · PaddlePaddle/Paddle2ONNX · GitHub

依赖:import onnx

python rename_onnx_model.py --model model.onnx --origin_names x y z --new_names x1 y1 z1 --save_file new_model.onnx

其中 origin_names 和 new_names,前者表示原模型中各个命名(可指定多个),后者表示新命名,两个参数指定的命名个数需要相同 

 

import argparse
import sys


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model',
        required=True,
        help='Path of directory saved the input model.')
    parser.add_argument(
        '--origin_names',
        required=True,
        nargs='+',
        help='The original name you want to modify.')
    parser.add_argument(
        '--new_names',
        required=True,
        nargs='+',
        help='The new name you want change to, the number of new_names should be same with the number of origin_names'
    )
    parser.add_argument(
        '--save_file', required=True, help='Path to save the new onnx model.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    import onnx
    model = onnx.load(args.model)
    output_tensor_names = set()
    for ipt in model.graph.input:
        output_tensor_names.add(ipt.name)
    for node in model.graph.node:
        for out in node.output:
            output_tensor_names.add(out)

    for origin_name in args.origin_names:
        if origin_name not in output_tensor_names:
            print("[ERROR] Cannot find tensor name '{}' in onnx model graph.".
                  format(origin_name))
            sys.exit(-1)
    if len(set(args.origin_names)) < len(args.origin_names):
        print(
            "[ERROR] There's dumplicate name in --origin_names, which is not allowed."
        )
        sys.exit(-1)
    if len(args.new_names) != len(args.origin_names):
        print(
            "[ERROR] Number of --new_names must be same with the number of --origin_names."
        )
        sys.exit(-1)
    if len(set(args.new_names)) < len(args.new_names):
        print(
            "[ERROR] There's dumplicate name in --new_names, which is not allowed."
        )
        sys.exit(-1)
    for new_name in args.new_names:
        if new_name in output_tensor_names:
            print(
                "[ERROR] The defined new_name '{}' is already exist in the onnx model, which is not allowed."
            )
            sys.exit(-1)

    for i, ipt in enumerate(model.graph.input):
        if ipt.name in args.origin_names:
            idx = args.origin_names.index(ipt.name)
            model.graph.input[i].name = args.new_names[idx]

    for i, node in enumerate(model.graph.node):
        for j, ipt in enumerate(node.input):
            if ipt in args.origin_names:
                idx = args.origin_names.index(ipt)
                model.graph.node[i].input[j] = args.new_names[idx]
        for j, out in enumerate(node.output):
            if out in args.origin_names:
                idx = args.origin_names.index(out)
                model.graph.node[i].output[j] = args.new_names[idx]

    for i, out in enumerate(model.graph.output):
        if out.name in args.origin_names:
            idx = args.origin_names.index(out.name)
            model.graph.output[i].name = args.new_names[idx]

    onnx.checker.check_model(model)
    onnx.save(model, args.save_file)
    print("[Finished] The new model saved in {}.".format(args.save_file))
    print("[DEBUG INFO] The inputs of new model: {}".format(
        [x.name for x in model.graph.input]))
    print("[DEBUG INFO] The outputs of new model: {}".format(
        [x.name for x in model.graph.output]))
修改前

修改后
本文章已经生成可运行项目
<think> 我们之前已经讨论了如何使用ONNX GraphSurgeon修改节点的输入,现在用户的问题是:修改节点输入后如何将计算图导出为ONNX模型。根据之前的引用[^1]和我们的知识,ONNX GraphSurgeon修改图后,需要将图结构导出为ONNX模型格式。 步骤如下: 1. 修改节点输入(如之前所述)。 2. 清理图中未使用的节点和张量(使用`cleanup`方法)。 3. 对节点进行拓扑排序(使用`toposort`方法)以确保节点的执行顺序正确。 4. 将图导出为ONNX模型(使用`export_onnx`函数),然后使用ONNX的保存函数保存为文件。 具体代码示例: 假设我们已经完成了对图中节点输入修改,接下来进行导出: ```python import onnx_graphsurgeon as gs import onnx # 1. 加载模型(假设已经进行了一系列修改) graph = gs.import_onnx(onnx.load("original_model.onnx")) # ... 这里对graph进行修改,比如修改节点输入 ... # 2. 清理图中未使用的节点和张量 graph.cleanup() # 3. 对节点进行拓扑排序 graph.nodes = graph.toposort().nodes # 4. 导出为ONNX模型 onnx_model = gs.export_onnx(graph) # 5. 保存模型 onnx.save(onnx_model, "modified_model.onnx") ``` 注意事项: - 在导出之前,务必进行清理(`cleanup`)和拓扑排序(`toposort`),否则可能导致模型结构错误或导出失败。 - 导出后的模型建议使用ONNX的检查工具(`onnx.checker.check_model`)进行验证。 如果用户需要动态修改模型输入维度(比如修改输入batch size为动态),可以在修改节点输入之前,修改输入张量的形状。例如: ```python # 修改整个图的输入张量(假设图的输入名为"input") for input_tensor in graph.inputs: if input_tensor.name == "input": input_tensor.shape = ["batch_size", 3, 224, 224] # 使用字符串"batch_size"表示动态维度 # 然后进行后续修改和导出 ``` 但是注意,动态维度的设置需要与模型中其他节点的形状推断兼容,否则可能导致形状推断错误。 如果用户修改节点输入后出现形状推断错误,需要检查修改后的张量形状是否与下游节点兼容。可以通过ONNX的形状推断工具(如`onnx.shape_inference.infer_shapes`)对导出的模型进行形状推断,或者使用ONNX Runtime进行验证。 现在,根据要求生成相关问题。 </think>### ONNX GraphSurgeon 修改节点输入后导出 ONNX 模型的方法 完成节点输入修改后,导出 ONNX 模型需遵循以下步骤: #### 1. **清理无效节点** [^1] ```python graph.cleanup() # 移除所有未被引用的中间张量 ``` #### 2. **拓扑排序重建依赖** ```python graph.nodes = graph.toposort() # 按执行顺序重排节点 ``` #### 3. **导出为 ONNX 模型** ```python modified_onnx_model = gs.export_onnx(graph) # 转换为 ONNX 模型对象 ``` #### 4. **保存模型文件** ```python onnx.save(modified_onnx_model, "modified_model.onnx") ``` #### 5. **完整导出流程示例** ```python import onnx import onnx_graphsurgeon as gs # 加载模型 graph = gs.import_onnx(onnx.load("original_model.onnx")) # 修改节点输入 (示例) target_node = [node for node in graph.nodes if node.name == "target_layer"][0] new_input = gs.Variable(name="new_input", dtype=np.float32, shape=(1,3,224,224)) target_node.inputs[0] = new_input # 导出流程 graph.cleanup() graph.nodes = graph.toposort().nodes onnx_model = gs.export_onnx(graph) onnx.save(onnx_model, "modified_model.onnx") ``` #### 关键注意事项 1. **形状一致性验证** 导出后需检查形状兼容性: ```python onnx.checker.check_model(onnx_model) # 验证模型完整性 ``` 2. **动态维度处理** 若需动态维度,显式设置: ```python new_input.shape = ["batch_size", 3, 224, 224] ``` 3. **数据类型匹配** 新输入张量数据类型需与原始一致: ```python np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx.TensorProto.FLOAT] new_input.dtype = np_dtype ``` 4. **节点连接验证** 确保所有输入输出正确连接: ```python for node in graph.nodes: assert all(i is not None for i in node.inputs) # 输入无空值 ``` 导出后建议使用 ONNX Runtime 验证模型可运行性: ```python import onnxruntime as ort ort.InferenceSession("modified_model.onnx") # 成功加载说明导出有效 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值