通常我们在开发中根据不同任务需要不同的预训练模型,因此需要同时加载多个模型文件。但是同时加载多个TensorFlow预训练模型时,若还是采用加载单个模型文件一样的方式则会因图冲突而加载失败。主要是因为不同对象里面的不同sess使用了同一进程空间下的相同的默认图graph。 因此,我们需要为为每个类(实例)单独创建一个graph
g1 = tf.Graph() #为每个类(实例)单独创建一个graph
g2 = tf.Graph()
sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
#亲测,若你训练模型时指定了设备,如上一行代码,则你restore时也要加上,不然会出错。
sess1 = tf.Session(graph=g1, config=sess_config)
sess2 = tf.Session(graph=g2, config=sess_config)
#加载模型1,
with sess1.as_default():
with sess1.graph.as_default():
tf.global_variables_initializer().run()
model_saver = tf.train.import_meta_graph(model_path_1+'model.meta')
model_cpt = tf.train.get_checkpoint_state(model_path_1)
model_saver.restore(sess1, model_cpt.model_checkpoint_path)
graph = tf.get_default_graph()