TensorFlow 读取 pb 文件时出现的 ValueError 问题

ValueError: Input 0 of node import/MobilenetV2/expanded_conv/depthwise/BatchNorm/cond/Assign/Switch was passed float from import/MobilenetV2/expanded_conv/depthwise/BatchNorm/moving_mean:0 incompatible with expected float_ref.
找到的解决方法:
ckpt 转 pb 时 batch normalization 出现的 ValueError 问题
好多人说有用,但是我的还没解决
参考:https://www.cnblogs.com/bonelee/p/8445261.html
还是没解决
我的报错代码:
pb_loader.py

...省略...

    def _load_pb(self):
        """Load the frozen GraphDef at ``self.model_path`` into a new graph and session.

        Sets ``self.graph``, ``self.input_tensors``, ``self.output_tensors`` and
        ``self.sess``. The GraphDef is imported exactly as stored in the file.
        NOTE(review): graphs that still contain ref-typed BatchNorm update ops
        (Assign / RefSwitch) may fail in ``tf.import_graph_def`` with the
        "incompatible with expected float_ref" ValueError.
        """
        # Tensors are requested by name: all inputs first, then all outputs.
        wanted = self.input_nodes + self.output_nodes
        n_inputs = len(self.input_nodes)

        with tf.gfile.GFile(self.model_path, "rb") as pb_file:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(pb_file.read())

        self.graph = tf.Graph()
        with self.graph.as_default():
            imported = tf.import_graph_def(graph_def, return_elements=wanted)
            # Split the returned list back into the input part and the output part.
            self.input_tensors = imported[:n_inputs]
            self.output_tensors = imported[n_inputs:]
        self.sess = tf.Session(graph=self.graph)

    def run_pb(self, input_list):
        """Run the loaded graph on ``input_list`` and return the fetched outputs.

        Args:
            input_list: one value per input node, in the same order as
                ``self.input_nodes``; value i is fed to ``self.input_tensors[i]``.

        Returns:
            The result of ``self.sess.run`` for ``self.output_tensors``.

        Raises:
            ValueError: if the number of values differs from the number of input
                nodes. ValueError subclasses Exception, so callers catching
                ``Exception`` are unaffected.
        """
        expected = len(self.input_nodes)
        if len(input_list) != expected:
            raise ValueError(
                f"Model input error!.Expected {expected} input, got {len(input_list)}")
        # Pair each input tensor with its value, positionally.
        feed_dict = dict(zip(self.input_tensors, input_list))
        return self.sess.run(self.output_tensors, feed_dict=feed_dict)

后来在一个 GitHub issue 的回答中找到了下面这段代码,问题顺利解决:

# Rewrite ref-typed BatchNorm update ops in a parsed GraphDef so that
# tf.import_graph_def no longer rejects them ("... incompatible with expected
# float_ref"). Expects `graph_def` (a parsed GraphDef) to already be in scope;
# nodes are modified in place.
for node in graph_def.node:
    if node.op == 'RefSwitch':
        # Switch on the variable's read value instead of on the ref itself.
        node.op = 'Switch'
        # Iterate a snapshot because entries of node.input are reassigned below.
        for index, name in enumerate(list(node.input)):
            if 'moving_' in name:
                node.input[index] = name + '/read'
    elif node.op == 'AssignSub':
        node.op = 'Sub'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']
    elif node.op == 'AssignAdd':
        node.op = 'Add'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']
    elif node.op == 'Assign':
        node.op = 'Identity'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']
        if 'validate_shape' in node.attr:
            del node.attr['validate_shape']
        if len(node.input) == 2:
            # input0: ref: should be from a Variable node; may be uninitialized.
            # input1: value: the value to be assigned to the variable.
            # Identity keeps only the value input.
            node.input[0] = node.input[1]
            del node.input[1]

修改后的代码如下:

    def _load_pb(self):
        """Load the frozen GraphDef at ``self.model_path`` into a new graph and session.

        Before importing, ref-typed BatchNorm update ops are rewritten by
        ``_fix_ref_variable_ops`` so that ``tf.import_graph_def`` does not fail
        with "... incompatible with expected float_ref".

        Sets ``self.graph``, ``self.input_tensors``, ``self.output_tensors`` and
        ``self.sess``.
        """
        return_elements = self.input_nodes + self.output_nodes
        input_node_num = len(self.input_nodes)
        with tf.gfile.GFile(self.model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        self.graph = tf.Graph()
        # Fix batch norm nodes *before* importing; the import is what fails.
        self._fix_ref_variable_ops(graph_def)
        with self.graph.as_default():
            tensor_nodes = tf.import_graph_def(
                graph_def, return_elements=return_elements)
            self.input_tensors = tensor_nodes[:input_node_num]
            self.output_tensors = tensor_nodes[input_node_num:]
        self.sess = tf.Session(graph=self.graph)

    @staticmethod
    def _fix_ref_variable_ops(graph_def):
        """Rewrite ref-consuming ops in ``graph_def`` in place.

        - RefSwitch -> Switch, rerouting inputs whose name contains 'moving_' to
          their '/read' tensor.
        - AssignSub / AssignAdd -> Sub / Add, dropping 'use_locking'.
        - Assign -> Identity of its value input, dropping 'use_locking' and
          'validate_shape'.
        """
        for node in graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                # Iterate a snapshot because entries of node.input are reassigned.
                for index, name in enumerate(list(node.input)):
                    if 'moving_' in name:
                        node.input[index] = name + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            elif node.op == 'AssignAdd':
                node.op = 'Add'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            elif node.op == 'Assign':
                node.op = 'Identity'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
                if 'validate_shape' in node.attr:
                    del node.attr['validate_shape']
                if len(node.input) == 2:
                    # Keep only the value input (input1); drop the ref (input0).
                    node.input[0] = node.input[1]
                    del node.input[1]

    def run_pb(self, input_list):
        """Run the loaded graph on ``input_list`` and return the fetched outputs.

        Args:
            input_list: one value per input node, in the same order as
                ``self.input_nodes``; value i is fed to ``self.input_tensors[i]``.

        Returns:
            The result of ``self.sess.run`` for ``self.output_tensors``.

        Raises:
            ValueError: if the number of values differs from the number of input
                nodes. ValueError subclasses Exception, so callers catching
                ``Exception`` are unaffected.
        """
        expected = len(self.input_nodes)
        if len(input_list) != expected:
            raise ValueError(
                f"Model input error!.Expected {expected} input, got {len(input_list)}")
        # Pair each input tensor with its value, positionally.
        feed_dict = dict(zip(self.input_tensors, input_list))
        return self.sess.run(self.output_tensors, feed_dict=feed_dict)

我的问题在于对 TensorFlow 的基础用法掌握不够,代码能力也比较弱:这段简单的代码读了半天才看懂,之后才知道它应该放在哪里。还是要好好打基础啊。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值