TF模型导出KeyError: ‘ParallelInterleaveDataset‘ in Tf >= 1.13)

KeyError: 'ParallelInterleaveDataset' in Tf >= 1.13

问题描述

在TF 1.14.0版本中使用tf.train.import_meta_graph()时出现KeyError: ‘ParallelInterleaveDataset’

解决办法

出错的原因是这个op被更名为ExperimentalParallelInterleaveDataset,因此需要在import之前将’ParallelInterleaveDataset’改为’ParallelInterleaveDataset’。

代码如下 1

import os

from google.protobuf import text_format
import tensorflow as tf

# EDIT HERE
_CONVERSION = {
    'ParallelInterleaveDataset': 'ExperimentalParallelInterleaveDataset',
    'MapAndBatchDatasetV2': 'ExperimentalMapAndBatchDataset',
}

def read_meta_graph_file(filename):
    meta_graph_def = tf.MetaGraphDef()
    if not tf.gfile.Exists(filename):
        raise IOError("File %s does not exist." % filename)
    file_content = tf.gfile.GFile(filename, "rb").read()
    try:
        meta_graph_def.ParseFromString(file_content)
        return meta_graph_def
    except Exception:
        pass

    try:
        text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    except text_format.ParseError as e:
        raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

    return meta_graph_def


def _rename_op(s):
    for k, v in _CONVERSION.items():
        if k in s:
            return s.replace(k, v)
    return s


def convert(input_file, output_file):
    meta_graph_def = read_meta_graph_file(input_file)
    for node in meta_graph_def.graph_def.node:
        node.name = _rename_op(node.name)
        node.op = _rename_op(node.op)
        node.input[:] = [_rename_op(s) for s in node.input]
    tf.io.write_graph(meta_graph_def,
                      os.path.dirname(output_file),
                      os.path.basename(output_file),
                      as_text=False)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--meta_file", required=True)
    parser.add_argument("--output_file", required=True)
    parser.add_argument("--check", action="store_true")
    args = parser.parse_args()
    convert(args.meta_file, args.output_file)
    if args.check:
        _saver = tf.train.import_meta_graph(args.output_file)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值