出现的问题:
如果是有batch normalzition,或者残差网络层,会出现:
ValueError: Input 0 of node vgg_16/conv1/conv1_1/BatchNorm/cond_1/AssignMovingAvg/Switch was passed float from vgg_16/conv1/conv1_1/BatchNorm/moving_mean:0 incompatible with expected float_ref.
类似问题
则需要在restore模型后加入:
-
# fix batch norm nodes
-
for node
in gd.node:
-
if node.op ==
'RefSwitch':
-
node.op =
'Switch'
-
for index
in xrange(len(node.input)):
-
if
'moving_'
in node.input[index]:
-
node.input[index] = node.input[index] +
'/read'
-
elif node.op ==
'AssignSub':
-
node.op =
'Sub'
-
if
'use_locking'
in node.attr:
del node.attr[
'use_locking']
参考:https://github.com/tensorflow/tensorflow/issues/3628
参考:https://www.cnblogs.com/bonelee/p/8445261.htm
1,这段代码可以加在ckpt2pb里面
2,也可以加在使用pb模型时候
tf.import_graph_def
之前。
附:ckpt2pb.py 参考
-
# -*- coding: utf-8 -*-
-
"""
-
Created on Mon Aug 27 15:26:44 2018
-
-
@author: me
-
"""
-
-
#https://blog.youkuaiyun.com/michael_yt/article/details/74737489
-
#https://blog.youkuaiyun.com/yjl9122/article/details/78341689
-
import tensorflow
as tf
-
import os.path
-
import argparse
-
from tensorflow.python.framework
import graph_util
-
-
MODEL_DIR =
"D:/T_mytest/clsTest/checkpoint/pb"
-
MODEL_NAME =
"frozen_model.pb"
-
-
if
not tf.gfile.Exists(MODEL_DIR):
#创建目录
-
tf.gfile.MakeDirs(MODEL_DIR)
-
-
def freeze_graph(model_folder):
-
checkpoint = tf.train.get_checkpoint_state(model_folder)
#检查目录下ckpt文件状态是否可用
-
input_checkpoint = checkpoint.model_checkpoint_path
#得ckpt文件路径
-
output_graph = os.path.join(MODEL_DIR, MODEL_NAME)
#PB模型保存路径
-
-
#output_node_names = "predictions" #原模型输出操作节点的名字
-
output_node_names =
"predicted_val_top_k,predicted_index_top_k"
#原模型输出操作节点的名字
-
saver = tf.train.import_meta_graph(input_checkpoint +
'.meta', clear_devices=
True)
#得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.
-
-
graph = tf.get_default_graph()
#获得默认的图
-
input_graph_def = graph.as_graph_def()
#返回一个序列化的图代表当前的图
-
-
with tf.Session()
as sess:
-
saver.restore(sess, input_checkpoint)
#恢复图并得到数据
-
#sess.run(tf.global_variables_initializer())
-
# fix batch norm nodes
-
for node
in input_graph_def.node:
-
if node.op ==
'RefSwitch':
-
node.op =
'Switch'
-
for index
in range(len(node.input)):
-
if
'moving_'
in node.input[index]:
-
node.input[index] = node.input[index] +
'/read'
-
elif node.op ==
'AssignSub':
-
node.op =
'Sub'
-
if
'use_locking'
in node.attr:
del node.attr[
'use_locking']
-
-
#print ("predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]})) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字
-
-
output_graph_def = graph_util.convert_variables_to_constants(
#模型持久化,将变量值固定
-
sess,
-
input_graph_def,
-
output_node_names.split(
",")
#如果有多个输出节点,以逗号隔开
-
)
-
with tf.gfile.GFile(output_graph,
"wb")
as f:
#保存模型
-
f.write(output_graph_def.SerializeToString())
#序列化输出
-
print(
"%d ops in the final graph." % len(output_graph_def.node))
#得到当前图有几个操作节点
-
-
#for op in graph.get_operations():
-
#print(op.name, op.values())
-
-
if __name__ ==
'__main__':
-
parser = argparse.ArgumentParser()
-
parser.add_argument(
"--model_folder",default=
"D:/T_mytest/clsTest/checkpoint", type=str, help=
"input ckpt model dir")
#命令行解析,help是提示符,type是输入的类型,
-
# 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments
-
aggs = parser.parse_args()
-
freeze_graph(aggs.model_folder)
-
# freeze_graph("model/ckpt") #模型目录