在使用freeze_gragh.py时,需要输入参数 --output_node_name,可以在原有网络结构的代码后面直接添加需要的节点,然后在代码中print对应的tensor即可,比如:
# my model :
...
with tf.name_scope("scope_output_node_name"):
final_output_ids = infer_outputs.predicted_ids[:, :, 0]
final_score = final_state[1][0][0]
print(final_output_ids)
print(final_score)
结果显示两个节点的名称为:
scope_final_output_ids/strided_slice
scope_final_output_ids/strided_slice_2
最后,可以使用freeze_graph.py如下:
python3 freeze_graph.py --input_graph=“models/gqa.batch2.pbtxt” --input_checkpoint=“models/batch_model_epoch0_val2.53_tst0.00.ckpt” --output_node_names=“scope_final_output_ids/strided_slice, scope_final_output_ids/strided_slice_2” --output_graph=“gqa.batch.ckpt.pb”