tensorflow 中的 graph 和 Session


写个笔记方便自己以后不记得了回头查看


1. 哪个graph? 哪个tensor /operation?

问题

使用tensorflow的时候就很奇怪,东建一个图,西建一个operation, 到底这个operation在哪个图里, 我现在玩的到底是哪个图, 这个图里又有哪些operation

总结

  1. 计算图只是定义计算计算流程的形式框架, 并不参与实际的计算
  2. 使用 graph = tf.Graph() 可以新建一个计算图实例
  3. 你要玩的张量或者operation 全都是在default默认图下的, 所以你要在哪个图下添加操作或者张量就得把哪张图当成默认图 .as_default() 的上下文管理器 with xxx.as_default(): 期间去操作
  4. 如果你没有用上面2中的语句来生成一个计算图实例,当你创建一个张量时,tensorflow也会默认给创建一张默认图,并把你加的东西放这张图上
  5. tf.Session() 叫做会话,通常用来链接前端操作与后端实际计算接口, 意思基本上就是Sess = tf.Session(graph=graph) 的意思就是Sess这个会话专门用来运行graph这个计算图, 如果Sess=tf.Session(), 那么默认计算默认图。
  6. 一个图可以在多个sess中运行,一个sess也能运行多个图

简单的代码示例

例子1

姿势不对,张量永远写在默认图上,你的新建图只是个路人甲

graph1 = tf.Graph()                # 新建一个图但是这graph1并不是默认图
a1 = tf.constant(2.0, tf.float32)  # 新建一个张量,这个张量放在默认图上
b1 = tf.constant(1.0, tf.float32)

graph2 = a1.graph                  # graph2:存放张量的图
graph3 = tf.get_default_graph()    # 默认图

print(graph1)
print(graph2)
print(graph3)
print('-----------------------------')
print(graph1 is graph2)            # False 第一个建的图不是张量放的图
print(graph2 is graph3)            # True  放张量的图一定是默认图

======================================
output:
<tensorflow.python.framework.ops.Graph object at 0x00000180493BCEF0>
<tensorflow.python.framework.ops.Graph object at 0x00000180493A58D0>
<tensorflow.python.framework.ops.Graph object at 0x00000180493A58D0>
-----------------------------
False
True

例子2

这是错误的姿势:把张量放在指定图上

graph1 = tf.Graph().as_default()    # 注意这里是错的:想的是这下graph1是默认图了
a1 = tf.constant(2.0, tf.float32)   # 建了张量, 本以为它会写在已经为默认图的graph1上了
b1 = tf.constant(1.0, tf.float32)   

graph2 = a1.graph                   # 获取a1 所在图
graph3 = tf.get_default_graph()     # 获取默认图

print(graph1 is graph2)             # 张量并没有写在graph1上
print(graph2 is graph3)             # True: 张量还是写在了默认图上   
print(graph1 is graph3)             # False: graph1 还是 graph1 并没有变 成默认图
                                    # 这是有道理的,不然默认图是啥,在哪,就再也找不回了

======================================
output:
False
True
False

这才是在指定图上写张量的正确姿势

graph1 = tf.Graph()
with graph1.as_default():              # 把graph1 当成默认图
    a2 = tf.constant(2.0, tf.float32)  # 在当成默认图的期间添加张量
    b2 = tf.constant(1.0, tf.float32)
graph2 = a2.graph
graph3 = tf.get_default_graph()

print("-----------------------------")
print(graph1 is graph2)                # True:张量写在了graph1上
print(graph1 is graph3)                # False:但是并不代表graph1就是默认图了
print(graph2 is graph3)                # False:没写在默认图上因为在写张量的时候
                                       # graph1 被当成了默认图
 ======================================
output:
True
False
False

2. 哪个sess? 运行哪个图?

问题

很多时候自己老是搞不清楚哪个sess运行的是哪个图。

总结

  1. 啥都不加运行默认图
  2. 加了啥就是运行啥图
  3. sess只是一个运行的会话,把要运行的命令传达给后端的计算系统
  4. 一个图可以在多个sess中运行,一个sess也能运行多个图
  5. TensorFlow会自动生成一个默认的计算图,如果没有特殊指定,运算会自动加入这个计算图中。TensorFlow中的会话也有类似的机制,但是TensorFlow不会自动生成默认的会话,而是需要手动指定。
  6. tf.Session()创建一个会话,当上下文管理器退出时会话关闭和资源释放自动完成. tf.Session().as_default()创建一个默认会话,当上下文管理器退出时会话没有关闭,还可以通过调用会话进行run()和eval()操作

简单的代码示例

例子1: tf.Session() 与 tf.Session().as_default()的区别

  1. tf.Session()
a = tf.constant(1.0 ,tf.float32)   
with tf.Session() as sess:
    print(sess.run(a))   
    print(sess)
sess.run(a)                 # 这里出现报错:使用了一个已经关闭的会话
                            # 其实也很好理解因为with的生命周期结束里面的也莫得了

======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB07A64A8>
RuntimeError: Attempted to use a closed Session.
  1. tf.Session().as_default()
b = tf.constant(1.0, tf.float32)
with tf.Session().as_default() as sess1:
    print(sess1.run(b))
    print(sess1)
print(sess1.run(b))         # 这里就可以用因为as_default()会让sess的线程保留
                            # 直到你手动结束为止sess1.close()或者最外层主线程的进程结束为止
                            
======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB23BFF98>
1.0
  1. session().as_default() 持续多久? 直到你要手动结束为止!
def test_life_time():
    b = tf.constant(1.0, tf.float32)
    with tf.Session().as_default() as sess1:
        print(sess1.run(b))
        print(sess1)
    print(sess1.run(b))
    return sess1
    
sess1 = test_life_time()        # 真的是可以反复使用
print(sess1)
print(sess1.run(b))

======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB14FEDD8>
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB14FEDD8>
1.0

但是你仍然需要传递出来,不然进程结束它还是莫得了

def test_life_time():
    b = tf.constant(1.0, tf.float32)
    with tf.Session().as_default() as sess1:
        print(sess1.run(b))
        print(sess1)
    print(sess1.run(b))
test_life_time()
print(sess1)               # 这里报错:因为没return还是受到了进程结束的限制


======================================
output:
1.0
<tensorflow.python.client.session.Session object at 0x0000020FB14FE550>
1.0
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-26-22acfee1bd9e> in <module>
     14 
     15 test_life_time()
---> 16 print(sess3)
     17 # print(sess1.run(b))

NameError: name 'sess3' is not defined

例子2:你不去建session 人家是不会默认给你一个session的

所以得先建, 然后全都可以设置为as_default()

a = tf.constant(1, tf.float32)
b = tf.constant(2, tf.float32)

g1 = tf.Graph()
with g1.as_default():
    a0 = tf.constant(3,tf.float32)
    a1 = tf.constant(4, tf.float32)
g2 = tf.get_default_graph()

yet_run_sess = tf.get_default_session()
print(sess)

sess1 = tf.Session()
ret_a = sess1.run(a)
print(sess1)
print(ret_a)

sess1.as_default()
print(tf.get_default_session())

after_run_sess = tf.get_default_session()
print(after_run_sess)

======================================
output:
None
<tensorflow.python.client.session.Session object at 0x0000020FB151D358>
1.0
None
None

例子3: 多图多session的一般使用模板(源码是搬运的)

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
# 在系统默认计算图上创建张量和操作
a=tf.constant([1.0,2.0])
b=tf.constant([2.0,1.0])
result = a+b
 
# 定义两个计算图
g1=tf.Graph()
g2=tf.Graph()
 
# 在计算图g1中定义张量和操作
with g1.as_default():
    a = tf.constant([1.0, 1.0])
    b = tf.constant([1.0, 1.0])
    result1 = a + b
 
with g2.as_default():
    a = tf.constant([2.0, 2.0])
    b = tf.constant([2.0, 2.0])
    result2 = a + b
 
 
# 在g1计算图上创建会话
with tf.Session(graph=g1) as sess:
    out = sess.run(result1)
    print 'with graph g1, result: {0}'.format(out)
 
with tf.Session(graph=g2) as sess:
    out = sess.run(result2)
    print 'with graph g2, result: {0}'.format(out)
 
# 在默认计算图上创建会话
with tf.Session(graph=tf.get_default_graph()) as sess:
    out = sess.run(result)
    print 'with graph default, result: {0}'.format(out)
 
print g1.version  # 返回计算图中操作的个数
g1 = tf.Graph() # 加载到Session 1的graph
g2 = tf.Graph() # 加载到Session 2的graph

sess1 = tf.Session(graph=g1) # Session1
sess2 = tf.Session(graph=g2) # Session2

# 加载第一个模型
with sess1.as_default(): 
    with g1.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(“model1/save/path”)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)
# 加载第二个模型
with sess2.as_default():  # 1
    with g2.as_default():  
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(“model2/save/path”)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

...

# 使用的时候
with sess1.as_default():
    with sess1.graph.as_default():  # 2
        ...

with sess2.as_default():
    with sess2.graph.as_default():
        ...

# 关闭sess
sess1.close()
sess2.close()

例子4: 各种图和各种session混合乱用

一句话就是图和序列化图分开存, import graph的时候要加as_default()

class Model:
    def __init__(self, model_file):
        self.graph = tf.Graph()
        self.graph_def = tf.GraphDef()
        with gfile.FastGFile(model_file, 'rb') as f:
            self.graph_def.ParseFromString(f.read())
        with self.graph.as_default():
            tf.import_graph_def(self.graph_def, name='')

        self.sess = tf.Session(graph=self.graph, config=config)

    def predict(self, images: list):
        output_node = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[-1].name)
        input_x = self.sess.graph.get_tensor_by_name('%s:0' % self.graph_def.node[0].name)

        w = input_x.shape[1]
        h = input_x.shape[2]
        data = []

        for img in images:
            img = img.resize((w, h))
            img = np.array(img).astype(float)
            data.append(img)

        feed = {input_x: data}
        out = self.sess.run(output_node, feed)
        return out
在这里插入代码片

3. 参考文章

http://blog.sina.com.cn/s/blog_628cc2b70102yonj.html

https://zhuanlan.zhihu.com/p/31308381

https://www.cnblogs.com/ywheunji/p/11390219.html

TensorFlow函数:tf.Session()和tf.Session().as_default()的区别

tf.get_default_session()

tensorflow 计算图和张量的使用

Tensorflow同时加载使用多个模型

TensorFlow 的 session 使用

官方中文文档

https://blog.youkuaiyun.com/dcrmg/article/details/79028032

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值