【Paddle2ONNX】为Paddle2ONNX添加裁剪模型功能

本文介绍了prune_paddle_model.py脚本的更新,扩展了对Paddle2.6.0API的支持,新增了处理模型中输入和输出节点修改的功能,包括读取模型、删除/插入操作等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

简介

原先的prune_paddle_model.py脚本只支持修改Paddle模型的输出节点,这里将Paddle相关API升级到2.6.0,同时添加了对修改输出节点的支持。

核心代码

先读取模型,这没什么好说的,但是要注意的是,针对PaddlePaddle2.6,load_inference_model返回的输入为输入节点的名称,输出则直接是输出节点。

paddle.enable_static()
print("Start to load paddle model...")
exe = static.Executor(paddle.CPUPlace())
[program, feed_target_names, fetch_targets] = static.io.load_inference_model(
    args.model_dir,
    exe,
    model_filename=args.model_filename,
    params_filename=args.params_filename)

Paddle模型中op_type为feed时表示改节点为输入节点。针对输入节点,我们要做的是,先删除再插入。

def prepend_feed_ops(program, feed_target_names):
    if len(feed_target_names) == 0:
        return

    global_block = program.global_block()
    feed_var = global_block.create_var(
        name='feed',
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)

    for i, name in enumerate(feed_target_names):
        if not global_block.has_var(name):
            print(
                "The input[{i}]: '{name}' doesn't exist in pruned inference program, which will be ignored in new saved model.".
                format(
                    i=i, name=name))
            continue
        out = global_block.var(name)
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': i})

def insert_by_op_type(program, op_names, op_type):
    global_block = program.global_block()
    need_to_remove_op_index = list()
    for i, op in enumerate(global_block.ops):
        if op.type == op_type:
            need_to_remove_op_index.append(i)
    for index in need_to_remove_op_index[::-1]:
        global_block._remove_op(index)
    program.desc.flush()

    if op_type == "feed":
        prepend_feed_ops(program, op_names)
    else:
        append_fetch_ops(program, op_names)

if args.input_names is not None:
    insert_by_op_type(program, args.input_names, 'feed')
    feed_vars = [program.global_block().var(name) for name in args.input_names]
else:
    feed_vars = [program.global_block().var(name) for name in feed_target_names]

op_type为fetch时,表示该节点为输出节点。对于输出节点我们要做的是先删除再添加。

def append_fetch_ops(program, fetch_target_names):
    """
    In this palce, we will add the fetch op
    """
    global_block = program.global_block()
    fetch_var = global_block.create_var(
        name='fetch',
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)
    print("the len of fetch_target_names:%d" % (len(fetch_target_names)))
    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})

def insert_by_op_type(program, op_names, op_type):
    global_block = program.global_block()
    need_to_remove_op_index = list()
    for i, op in enumerate(global_block.ops):
        if op.type == op_type:
            need_to_remove_op_index.append(i)
    for index in need_to_remove_op_index[::-1]:
        global_block._remove_op(index)
    program.desc.flush()

    if op_type == "feed":
        prepend_feed_ops(program, op_names)
    else:
        append_fetch_ops(program, op_names)

if args.output_names is not None:
    insert_by_op_type(program, args.output_names, 'fetch')
    fetch_vars = [program.global_block().var(out_name) for out_name in args.output_names]
else:
    fetch_vars = [out_var for out_var in fetch_targets]

参考文档

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值