问题描述
在TF 1.14.0版本中使用tf.train.import_meta_graph()时出现KeyError: ‘ParallelInterleaveDataset’
解决办法
出错的原因是这个op被更名为ExperimentalParallelInterleaveDataset,因此需要在import之前将meta graph中的'ParallelInterleaveDataset'改为'ExperimentalParallelInterleaveDataset'。
代码如下:
import os
from google.protobuf import text_format
import tensorflow as tf
# Old op name -> replacement op name, consumed by _rename_op().
# EDIT HERE: add one entry per op that has to be renamed.
_CONVERSION = {
    "ParallelInterleaveDataset": "ExperimentalParallelInterleaveDataset",
    "MapAndBatchDatasetV2": "ExperimentalMapAndBatchDataset",
}
def read_meta_graph_file(filename):
    """Load a ``MetaGraphDef`` from a binary or text protobuf file.

    The content is first parsed as a serialized (binary) protobuf. If that
    fails, it is decoded as UTF-8 and parsed as text-format protobuf.

    Args:
        filename: Path of the meta graph file to read.

    Returns:
        The parsed ``tf.MetaGraphDef``.

    Raises:
        IOError: If the file does not exist, or if its content can be parsed
            neither as binary nor as text protobuf.
    """
    meta_graph_def = tf.MetaGraphDef()
    if not tf.gfile.Exists(filename):
        raise IOError("File %s does not exist." % filename)

    # Close the handle deterministically instead of leaving it to the GC.
    with tf.gfile.GFile(filename, "rb") as f:
        file_content = f.read()

    # First attempt: binary protobuf.
    try:
        meta_graph_def.ParseFromString(file_content)
        return meta_graph_def
    except Exception:  # pylint: disable=broad-except
        # Deliberately broad: any failure here only means "not a binary
        # MetaGraphDef", so the text parser below gets its turn.
        pass

    # Second attempt: text protobuf. Merge() adds to the existing message,
    # so drop anything the failed binary parse may have left behind first.
    meta_graph_def.Clear()
    try:
        text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    except (UnicodeDecodeError, text_format.ParseError) as e:
        # Content that is not valid UTF-8 is just as unparseable as
        # malformed text, so report both through the same IOError.
        raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
    return meta_graph_def
def _rename_op(s):
    """Replace an old op name inside ``s`` with its new name.

    The entries of ``_CONVERSION`` are tried in order and only the first one
    that actually changes ``s`` is applied. Occurrences that already carry
    the new name are left untouched, so applying the function to an already
    converted string returns it unchanged (a plain substring replace would
    turn ``ExperimentalFoo`` into ``ExperimentalExperimentalFoo`` whenever
    the new name contains the old one).

    Args:
        s: A node name, op type or input reference.

    Returns:
        ``s`` with the old name replaced, or ``s`` itself if nothing matched.
    """
    for old, new in _CONVERSION.items():
        if old not in s:
            continue
        # Split on the new name first: the pieces in between contain only
        # unconverted occurrences, which can then be replaced safely.
        renamed = new.join(piece.replace(old, new) for piece in s.split(new))
        if renamed != s:
            return renamed
    return s
def convert(input_file, output_file):
    """Rewrite a meta graph file so that renamed ops use their new names.

    Reads ``input_file``, applies ``_rename_op`` to the name, op type and
    inputs of every node in ``graph_def.node``, and writes the resulting
    ``MetaGraphDef`` to ``output_file`` as a binary protobuf.

    Args:
        input_file: Path of the meta graph file to convert.
        output_file: Path the converted meta graph is written to.
    """
    meta_graph_def = read_meta_graph_file(input_file)
    for node in meta_graph_def.graph_def.node:
        # Names and inputs go through the same renaming as the op type, so
        # input references keep matching the (possibly renamed) node names.
        node.name = _rename_op(node.name)
        node.op = _rename_op(node.op)
        node.input[:] = [_rename_op(s) for s in node.input]

    # dirname() is "" for a bare file name; fall back to the current
    # directory rather than handing write_graph an empty logdir.
    output_dir = os.path.dirname(output_file) or "."
    tf.io.write_graph(meta_graph_def,
                      output_dir,
                      os.path.basename(output_file),
                      as_text=False)
if __name__ == "__main__":
    import argparse

    # Command-line entry point: convert --meta_file into --output_file and,
    # with --check, import the converted file again.
    cli = argparse.ArgumentParser()
    for required_flag in ("--meta_file", "--output_file"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--check", action="store_true")
    args = cli.parse_args()

    convert(args.meta_file, args.output_file)
    if args.check:
        # Re-importing the converted file is the sanity check: it raises if
        # the graph still cannot be imported.
        _saver = tf.train.import_meta_graph(args.output_file)