简介
原先的prune_paddle_model.py脚本只支持修改Paddle模型的输出节点。这里将Paddle相关API升级到2.6.0,同时添加了对修改输入节点的支持。
核心代码
先读取模型,这一步没什么特别之处。但需要注意的是,在PaddlePaddle 2.6中,load_inference_model返回的输入部分是输入节点的名称列表,而输出部分则直接是输出节点对应的变量(Variable)对象。
paddle.enable_static()
print("Start to load paddle model...")
exe = static.Executor(paddle.CPUPlace())
# load_inference_model yields the program, the *names* of the feed targets,
# and the fetch target variables themselves.
program, feed_target_names, fetch_targets = static.io.load_inference_model(
    args.model_dir,
    exe,
    model_filename=args.model_filename,
    params_filename=args.params_filename,
)
Paddle模型中op_type为feed的算子表示该节点为输入节点。针对输入节点,我们的做法是先删除原有的feed算子,再按新的输入名称重新插入。
def prepend_feed_ops(program, feed_target_names):
    """Insert one ``feed`` op per requested input at the head of the global block.

    Args:
        program: the static Program whose global block receives the feed ops.
        feed_target_names: names of the variables to feed. Names that do not
            exist in the global block are reported and skipped.

    Returns:
        None. ``program`` is modified in place.
    """
    if not feed_target_names:
        return
    global_block = program.global_block()
    feed_var = global_block.create_var(
        name='feed',
        type=core.VarDesc.VarType.FEED_MINIBATCH,
        persistable=True)
    # ``col`` is the slot index of each fed input. Count only the inputs that
    # are actually fed, so skipped names do not leave gaps in the sequence.
    col = 0
    for i, name in enumerate(feed_target_names):
        if not global_block.has_var(name):
            print(
                "The input[{i}]: '{name}' doesn't exist in pruned inference program, which will be ignored in new saved model.".
                format(
                    i=i, name=name))
            continue
        out = global_block.var(name)
        global_block._prepend_op(
            type='feed',
            inputs={'X': [feed_var]},
            outputs={'Out': [out]},
            attrs={'col': col})
        col += 1
def insert_by_op_type(program, op_names, op_type):
    """Replace every ``feed`` or ``fetch`` op in the global block.

    All existing ops whose type equals ``op_type`` are removed. New ops are
    then created for ``op_names``: feed ops via ``prepend_feed_ops`` or fetch
    ops via ``append_fetch_ops``.

    Args:
        program: the static Program to modify in place.
        op_names: variable names to feed (``op_type == 'feed'``) or to fetch
            (``op_type == 'fetch'``).
        op_type: either ``'feed'`` or ``'fetch'``.

    Raises:
        ValueError: if ``op_type`` is neither ``'feed'`` nor ``'fetch'``.
            The program is left untouched in that case.
    """
    # Validate before mutating, so an unsupported value can neither delete
    # ops of that type nor fall through to appending fetch ops.
    if op_type not in ("feed", "fetch"):
        raise ValueError(
            "op_type must be 'feed' or 'fetch', got {!r}".format(op_type))
    global_block = program.global_block()
    indices_to_remove = [
        i for i, op in enumerate(global_block.ops) if op.type == op_type
    ]
    # Remove from the back so the indices still to be removed stay valid.
    for index in reversed(indices_to_remove):
        global_block._remove_op(index)
    program.desc.flush()
    if op_type == "feed":
        prepend_feed_ops(program, op_names)
    else:
        append_fetch_ops(program, op_names)
if args.input_names is not None:
    insert_by_op_type(program, args.input_names, 'feed')
    # prepend_feed_ops reports unknown input names as "ignored in new saved
    # model" and skips them. Drop the same names here so feed_vars matches the
    # inserted feed ops, instead of raising on block.var().
    feed_vars = [
        program.global_block().var(name)
        for name in args.input_names
        if program.global_block().has_var(name)
    ]
else:
    feed_vars = [program.global_block().var(name) for name in feed_target_names]
op_type为fetch的算子表示该节点为输出节点。针对输出节点,我们的做法是先删除原有的fetch算子,再按新的输出名称重新添加。
def append_fetch_ops(program, fetch_target_names):
    """Append one ``fetch`` op per requested output to the global block.

    Args:
        program: the static Program whose global block receives the fetch ops.
        fetch_target_names: names of the variables to fetch. Every name must
            exist in the global block.

    Returns:
        None. ``program`` is modified in place.

    Raises:
        ValueError: if any name in ``fetch_target_names`` is not a variable of
            the global block. The check runs before any fetch op is added.
    """
    global_block = program.global_block()
    # A fetch op naming a nonexistent variable would yield a broken program,
    # so reject unknown outputs up front with a clear message.
    missing = [
        name for name in fetch_target_names if not global_block.has_var(name)
    ]
    if missing:
        raise ValueError(
            "Output name(s) {} do not exist in the inference program.".format(
                missing))
    fetch_var = global_block.create_var(
        name='fetch',
        type=core.VarDesc.VarType.FETCH_LIST,
        persistable=True)
    print("the len of fetch_target_names:%d" % (len(fetch_target_names)))
    for i, name in enumerate(fetch_target_names):
        global_block.append_op(
            type='fetch',
            inputs={'X': [name]},
            outputs={'Out': [fetch_var]},
            attrs={'col': i})
def insert_by_op_type(program, op_names, op_type):
    """Replace every ``feed`` or ``fetch`` op in the global block.

    All existing ops whose type equals ``op_type`` are removed. New ops are
    then created for ``op_names``: feed ops via ``prepend_feed_ops`` or fetch
    ops via ``append_fetch_ops``.

    NOTE: this file defines ``insert_by_op_type`` twice. This later definition
    is the one in effect for code that runs after it.

    Args:
        program: the static Program to modify in place.
        op_names: variable names to feed (``op_type == 'feed'``) or to fetch
            (``op_type == 'fetch'``).
        op_type: either ``'feed'`` or ``'fetch'``.

    Raises:
        ValueError: if ``op_type`` is neither ``'feed'`` nor ``'fetch'``.
            The program is left untouched in that case.
    """
    # Validate before mutating, so an unsupported value can neither delete
    # ops of that type nor fall through to appending fetch ops.
    if op_type not in ("feed", "fetch"):
        raise ValueError(
            "op_type must be 'feed' or 'fetch', got {!r}".format(op_type))
    global_block = program.global_block()
    indices_to_remove = [
        i for i, op in enumerate(global_block.ops) if op.type == op_type
    ]
    # Remove from the back so the indices still to be removed stay valid.
    for index in reversed(indices_to_remove):
        global_block._remove_op(index)
    program.desc.flush()
    if op_type == "feed":
        prepend_feed_ops(program, op_names)
    else:
        append_fetch_ops(program, op_names)
if args.output_names is None:
    # Keep the model's original outputs (already Variable objects).
    fetch_vars = list(fetch_targets)
else:
    # Rebuild the fetch ops for the requested outputs, then look the
    # corresponding variables up by name.
    insert_by_op_type(program, args.output_names, 'fetch')
    fetch_vars = [
        program.global_block().var(requested)
        for requested in args.output_names
    ]