tensorflow 中的 graph和Session
写个笔记方便自己以后不记得了回头查看
1. 哪个graph? 哪个tensor /operation?
问题
使用tensorflow的时候就很奇怪,东建一个图,西建一个operation, 到底这个operation在哪个图里, 我现在玩的到底是哪个图, 这个图里又有哪些operation
总结
- 计算图只是定义计算计算流程的形式框架, 并不参与实际的计算
- 使用 graph = tf.Graph() 可以新建一个计算图实例
- 你要玩的张量或者operation 全都是在default默认图下的, 所以你要在哪个图下添加操作或者张量就得把哪张图当成默认图 .as_default() 的上下文管理器 with xxx.as_default(): 期间去操作
- 如果你没有用上面2中的语句来生成一个计算图实例,当你创建一个张量时,tensorflow也会默认给创建一张默认图,并把你加的东西放这张图上
- tf.Session() 叫做会话,通常用来链接前端操作与后端实际计算接口, 意思基本上就是Sess = tf.Session(graph=graph) 的意思就是Sess这个会话专门用来运行graph这个计算图, 如果Sess=tf.Session(), 那么默认计算默认图。
- 一个图可以在多个sess中运行,一个sess也能运行多个图
简单的代码示例
例子1
姿势不对,张量永远写在默认图上
,你的新建图只是个路人甲
graph1 = tf.Graph() # 新建一个图但是这graph1并不是默认图
a1 = tf.constant(2.0, tf.float32) # 新建一个张量,这个张量放在默认图上
b1 = tf.constant(1.0, tf.float32)
graph2 = a1.graph # graph2:存放张量的图
graph3 = tf.get_default_graph() # 默认图
print(graph1)
print(graph2)
print(graph3)
print('-----------------------------')
print(graph1 is graph2) # False 第一个建的图不是张量放的图
print(graph2 is graph3) # True 放张量的图一定是默认图
======================================
output:
<tensorflow.python.framework.ops.Graph object at 0x00000180493BCEF0>
<tensorflow.python.framework.ops.Graph object at 0x00000180493A58D0>
<tensorflow.python.framework.ops.Graph object at 0x00000180493A58D0>
-----------------------------
False
True
例子2
这是错误的姿势
:把张量放在指定图上
graph1 = tf.Graph().as_default() # 注意这里是错的:想的是这下graph1是默认图了
a1 = tf.constant(2.0, tf.float32) # 建了张量, 本以为它会写在已经为默认图的graph1上了
b1 = tf.constant(1.0, tf.float32)
graph2 = a1.graph # 获取a1 所在图
graph3 = tf.get_default_graph() # 获取默认图
print(graph1 is graph2) # 张量并没有写在graph1上
print(graph2 is graph3) # True: 张量还是写在了默认图上
print(graph1 is graph3) # False: graph1 还是 graph1 并没有变 成默认图
# 这是有道理的,不然默认图是啥,在哪,就再也找不回了
======================================
output:
False
True
False
这才是在指定图上写张量的正确姿势
graph1 = tf.Graph()
with graph1.as_default(): # 把graph1 当成默认图
a2 = tf.constant(2.0, tf.float32) # 在当成默认图的期间添加张量
b2 = tf.constant(1.0, tf.float32)
graph2 = a2.graph
graph3 = tf.get_default_graph()
print("-----------------------------")
print(graph1 is graph2) # True:张量写在了graph1上
print(graph1 is graph3) # False:但是并不代表graph1就是默认图了
print(graph2 is graph3) # False:没写在默认图上因为在写张量的时候
# graph1 被当成了默认图
======================================
output:
True
False
False
2. 哪个sess? 运行哪个图?
问题
很多时候自己老是搞不清楚哪个sess运行的是哪个图。
总结
- 啥都不加运行默认图
- 加了啥就是运行啥图
- sess只是一个运行的会话,把要运行的命令传达给后端的计算系统
- 一个图可以在多个sess中运行,一个sess也能运行多个图
- TensorFlow会自动生成一个默认的计算图,如果没有特殊指定,运算会自动加入这个计算图中。TensorFlow中的会话也有类似的机制,但是TensorFlow不会自动生成默认的会话,而是需要手动指定。
- tf.Session()创建一个会话,当上下文管理器退出时会话关闭和资源释放自动完成. tf.Session().as_default()创建一个默认会话,当上下文管理器退出时会话没有关闭,还可以通过调用会话进行run()和eval()操作
简单的代码示例
例子1: tf.Session() 与 tf.Session().as_default()的区别
- tf.Session()
a = tf.constant(1.0 ,tf.float32)
with tf.Session() as sess:
print(sess.run(a))
print(sess)
sess.run(a) # 这里出现报错:使用了一个已经关闭的会话
# 其实也很好理解因为with的生命周期结束里面的也莫得了
======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB07A64A8>
RuntimeError: Attempted to use a closed Session.
- tf.Session().as_default()
b = tf.constant(1.0, tf.float32)
with tf.Session().as_default() as sess1:
print(sess1.run(b))
print(sess1)
print(sess1.run(b)) # 这里就可以用因为as_default()会让sess的线程保留
# 直到你手动结束为止sess1.close()或者最外层主线程的进程结束为止
======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB23BFF98>
1.0
- session().as_default() 持续多久? 直到你要手动结束为止!
def test_life_time():
b = tf.constant(1.0, tf.float32)
with tf.Session().as_default() as sess1:
print(sess1.run(b))
print(sess1)
print(sess1.run(b))
return sess1
sess1 = test_life_time() # 真的是可以反复使用
print(sess1)
print(sess1.run(b))
======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB14FEDD8>
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB14FEDD8>
1.0
但是你仍然需要传递出来,不然进程结束它还是莫得了
def test_life_time():
b = tf.constant(1.0, tf.float32)
with tf.Session().as_default() as sess1:
print(sess1.run(b))
print(sess1)
print(sess1.run(b))
test_life_time()
print(sess1) # 这里报错:因为没return还是受到了进程结束的限制
======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB14FE550>
1.0
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-26-22acfee1bd9e> in <module>
14
15 test_life_time()
---> 16 print(sess3)
17 # print(sess1.run(b))
NameError: name 'sess3' is not defined
例子2:你不去建session 人家是不会默认给你一个session的
所以得先建, 然后全都可以设置为as_default()
a = tf.constant(1, tf.float32)
b = tf.constant(2, tf.float32)
g1 = tf.Graph()
with g1.as_default():
a0 = tf.constant(3,tf.float32)
a1 = tf.constant(4, tf.float32)
g2 = tf.get_default_graph()
yet_run_sess = tf.get_default_session()
print(sess)
sess1 = tf.Session()
ret_a = sess1.run(a)
print(sess1)
print(ret_a)
sess1.as_default()
print(tf.get_default_session())
after_run_sess = tf.get_default_session()
print(after_run_sess)
======================================
output:
None
<tensorflow.python.client.session.Session object at 0x0000020FB151D358>
1.0
None
None
例子3: 多图多session的一般使用模板(源码是搬运的)
# -*- coding: utf-8 -*-)
import tensorflow as tf
# 在系统默认计算图上创建张量和操作
a=tf.constant([1.0,2.0])
b=tf.constant([2.0,1.0])
result = a+b
# 定义两个计算图
g1=tf.Graph()
g2=tf.Graph()
# 在计算图g1中定义张量和操作
with g1.as_default():
a = tf.constant([1.0, 1.0])
b = tf.constant([1.0, 1.0])
result1 = a + b
with g2.as_default():
a = tf.constant([2.0, 2.0])
b = tf.constant([2.0, 2.0])
result2 = a + b
# 在g1计算图上创建会话
with tf.Session(graph=g1) as sess:
out = sess.run(result1)
print 'with graph g1, result: {0}'.format(out)
with tf.Session(graph=g2) as sess:
out = sess.run(result2)
print 'with graph g2, result: {0}'.format(out)
# 在默认计算图上创建会话
with tf.Session(graph=tf.get_default_graph()) as sess:
out = sess.run(result)
print 'with graph default, result: {0}'.format(out)
print g1.version # 返回计算图中操作的个数
g1 = tf.Graph() # 加载到Session 1的graph
g2 = tf.Graph() # 加载到Session 2的graph
sess1 = tf.Session(graph=g1) # Session1
sess2 = tf.Session(graph=g2) # Session2
# 加载第一个模型
with sess1.as_default():
with g1.as_default():
tf.global_variables_initializer().run()
model_saver = tf.train.Saver(tf.global_variables())
model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”)
model_saver.restore(sess, model_ckpt.model_checkpoint_path)
# 加载第二个模型
with sess2.as_default(): # 1
with g2.as_default():
tf.global_variables_initializer().run()
model_saver = tf.train.Saver(tf.global_variables())
model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”)
model_saver.restore(sess, model_ckpt.model_checkpoint_path)
...
# 使用的时候
with sess1.as_default():
with sess1.graph.as_default(): # 2
...
with sess2.as_default():
with sess2.graph.as_default():
...
# 关闭sess
sess1.close()
sess2.close()
例子4: 各种图和各种session混合乱用
一句话就是图和序列化图分开存, import graph的时候要加as_default()
class Model:
def __init__(self, model_file):
self.graph = tf.Graph()
self.graph_def = tf.GraphDef()
with gfile.FastGFile(model_file, 'rb') as f:
self.graph_def.ParseFromString(f.read())
with self.graph.as_default():
tf.import_graph_def(self.graph_def, name='')
self.sess = tf.Session(graph=self.graph, config=config)
def predict(self, images: list):
output_node = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[-1].name)
input_x = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[0].name)
w = input_x.shape[1]
h = input_x.shape[2]
data = []
for img in images:
img = img.resize((w, h))
img = np.array(img).astype(float)
data.append(img)
feed = {input_x: data}
out = self.sess.run(output_node, feed)
return out
在这里插入代码片
3. 参考文章
http://blog.sina.com.cn/s/blog_628cc2b70102yonj.html
https://zhuanlan.zhihu.com/p/31308381
https://www.cnblogs.com/ywheunji/p/11390219.html