tf.Session

参考 TensorFlow函数:tf.Session()和tf.Session().as_default()的区别 - 云+社区 - 腾讯云

目录

一、概述

二、函数列表

1、__init__

2、__enter__

3、__exit__

4、as_default

5、close

6、list_devices

7、make_callable

8、partial_run

9、partial_run_setup

10、reset

11、run



一、概述

一个运行TensorFlow操作的类。会话对象封装了执行操作对象和计算张量对象的环境。

例:

# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b

# Launch the graph in a session.
sess = tf.Session()

# Evaluate the tensor `c`.
print(sess.run(c))

会话可能拥有资源,比如tf.variable,tf.QueueBase, tf.ReaderBase。当不再需要这些资源时,释放它们是很重要的。为此,可以调用tf.Session。关闭会话上的方法,或将会话用作上下文管理器。下面两个例子是等价的:

# Using the `close()` method.
sess = tf.Session()
sess.run(...)
sess.close()

# Using the context manager.
with tf.Session() as sess:
  sess.run(...)

ConfigProto协议缓冲区公开会话的各种配置选项。例如,要创建一个使用设备放置软约束的会话,并记录结果的放置决策,创建一个会话如下:

# Launch the graph in a session that allows soft device placement and
# logs the placement decisions.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=True))

二、函数列表

1、__init__

__init__(
    target='',
    graph=None,
    config=None
)

 创建一个新的TensorFlow会话。如果在构造会话时没有指定图形参数,则会话中将启动缺省图。如果在同一过程中使用多个图(使用tf.Graph()创建),则必须为每个图使用不同的会话,但是每个图可以在多个会话中使用。在这种情况下,将要显式启动的图形传递给会话构造函数通常更清楚。

参数:

  • target:  (可选)。要连接到的执行引擎。默认使用进程内引擎。有关更多示例,请参见分布式TensorFlow。
  • graph:  (可选)。将要启动的图表(如上所述)。
  • config:  (可选)带有会话配置选项的ConfigProto协议缓冲区。

2、__enter__

__enter__()

3、__exit__

__exit__(
    exec_type,
    exec_value,
    exec_tb
)

4、as_default

as_default()

返回使此对象成为默认会话的上下文管理器。使用with关键字指定对tf.Operation.run或tf.Tensor的调用。eval应该在这个会话中执行。

c = tf.constant(..)
sess = tf.Session()

with sess.as_default():
  assert tf.get_default_session() is sess
  print(c.eval())

要获取当前默认会话,请使用tf.get_default_session。注意:当你退出上下文时,as_default上下文管理器不会关闭会话,你必须显式地关闭会话。 

c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
  print(c.eval())
# ...
with sess.as_default():
  print(c.eval())

sess.close()

或者,你可以使用with tf.Session():创建一个在退出上下文时自动关闭的会话,包括在引发未捕获异常时。 

注意:默认会话是当前线程的属性。如果您创建了一个新线程,并且希望在该线程中使用默认会话,则必须在该线程的函数中显式地添加一个带有sess() .as_default():的会话。注意:使用ssh .as_default():块输入不会影响当前默认图。如果你正在使用多个图形,那么sess。图与tf值不同。get_default_graph,您必须显式地输入一个带有sess.graph.as_default():块的参数来执行sess。绘制默认图形。

返回值:

  • 使用此会话作为默认会话的上下文管理器。

5、close

close()

关闭会话。调用此方法释放与会话关联的所有资源。

异常:

  • tf.errors.OpError: Or one of its subclasses if an error occurs while closing the TensorFlow session.

6、list_devices

list_devices()

列表中的每个元素都具有以下属性:

devices = sess.list_devices()
for d in devices:
  print(d.name)

列表中的每个元素都具有以下属性:

  • name:一个带有设备全名的字符串。例如:/job:worker/ copy:0/task:3/device:CPU:0
  • device_type:设备的类型(例如CPU、GPU、TPU)
  • memory_limit:设备上可用的最大内存量。

可能产生的异常:

  • tf.errors.OpError: If it encounters an error (e.g. session is in an invalid state, or network errors occur).

返回值:

  • 会话中的设备列表。

7、make_callable

make_callable(
    fetches,
    feed_list=None,
    accept_options=False
)

返回一个运行特定步骤的Python可调用函数。返回的可调用函数将接受len(feed_list)参数,其类型必须与feed_list的各个元素的提要值兼容。例如,如果feed_list的元素i是tf.Tensor,返回的可调用的第i个参数必须是一个numpy ndarray(或可转换为ndarray的东西),它具有匹配的元素类型和形状。返回的可调用函数将具有与tf.Session.run(fetches,…)相同的返回类型。例如,如果fetches是tf。张量,可调用的将返回一个numpy ndarray;如果fetches是tf。操作,它将返回None。

参数:

  • fetches:  要获取的值或值列表。有关允许获取类型的详细信息,请参见tf.Session.run。
  • feed_list:  (可选)。feed_dict键的列表。有关允许的提要键类型的详细信息,请参见tf.Session.run。
  • accept_options:(可选)。如果为真,返回的Callable将能够接受tf.RunOptions和tf.RunMetadata分别作为可选的关键字参数选项和run_metadata,具有与tf.Session.run相同的语法和语义,这对于某些用例(分析和调试)是有用的,但是会导致可调用程序性能的显著下降。默认值:False。

返回值:

  • 调用时将执行feed_list定义的步骤并在此会话中获取的函数。

可能产生的异常:

  • TypeError: If fetches or feed_list cannot be interpreted as arguments to tf.Session.run.

8、partial_run

partial_run(
    handle,
    fetches,
    feed_dict=None
)

使用更多的提要和获取继续执行。这是实验性的,可能会发生变化。 要使用部分执行,用户首先调用partial_run_setup(),然后调用partial_run()序列。partial_run_setup指定将在后续partial_run调用中使用的提要和获取列表。可选的feed_dict参数允许调用者覆盖图中张量的值。有关更多信息,请参见run()。

下面是一个简单的例子:

a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(r1, c)

h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
res = sess.partial_run(h, r2, feed_dict={c: res})

参数:

  • handle:  用于一系列部分运行的句柄。
  • fetches:  单个图形元素、一组图形元素或一个字典,其值是图形元素或图形元素列表(请参阅运行文档)。
  • feed_dict:将图形元素映射到值的字典(如上所述)。

返回值:

  • 如果fetches是单个图形元素,则使用单个值;如果fetches是列表,则使用值列表;如果fetches是字典,则使用与之相同的键的字典(有关运行,请参阅文档)。

可能产生的异常:

  • tf.errors.OpError: Or one of its subclasses on error.

9、partial_run_setup

partial_run_setup(
    fetches,
    feeds=None
)

为部分运行设置一个包含提要和获取的图。这是实验性的,可能会发生变化。注意,与run相反,提要只指定图元素。张量将由后续的partial_run调用提供。

参数:

  • fetches:  单个图元素,或一组图元素。
  • feeds:  单个图元素,或图元素列表。

返回值:

  • 用于部分运行的句柄。

可能产生的异常:

  • RuntimeError: If this Session is in an invalid state (e.g. has been closed).
  • TypeError: If fetches or feed_dict keys are of an inappropriate type.
  • tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.

10、reset

@staticmethod
reset(
    target,
    containers=None,
    config=None
)

在目标上重置资源容器,并关闭所有连接的会话。资源容器分布在与目标相同的集群中的所有worker上。当重置目标上的资源容器时,将清除与该容器关联的资源。特别是,容器中的所有变量都将成为未定义的:它们将丢失它们的值和形状。

注意:

  1. reset()目前仅为分布式会话实现。
  2. 按目标命名的关于主机的任何会话都将关闭。

如果没有提供资源容器,则重置所有容器。

参数:

  • target:  要连接到的执行引擎。
  • containers:  资源容器名称字符串的列表,如果要重置所有容器,则为None。
  • config:  (可选)带有配置选项的协议缓冲区。

可能产生的异常:

  • tf.errors.OpError: Or one of its subclasses if an error occurs while resetting containers.

11、run

run(
    fetches,
    feed_dict=None,
    options=None,
    run_metadata=None
)

在读取中运行操作并计算张量。该方法运行TensorFlow计算的一个“步骤”,通过运行必要的图的段来执行每一个操作,并在fetches中计算每个张量,用feed_dict中的值替换相应的输入值。fetches参数可以是一个单独的图形元素,也可以是一个任意嵌套的列表、元组、namedtuple、dict或OrderedDict,它的叶子中包含图形元素。图形元素可以是以下类型之一:

  • 一个tf.Operation。对应的获取值将为None。
  • tf.Tensor。相应的获取值将是一个包含该张量值的numpy ndarray。
  • tf.SparseTensor。对应的获取值将是tf.compat.v1.SparseTensorValue包含稀疏张量的值。
  • 一个get_tensor_handle操作符。相应的获取值将是一个包含该张量句柄的numpy ndarray。
  • 一个字符串,它是图中张量或运算的名称。

run()返回的值具有与fetches参数相同的形状,其中叶子被TensorFlow返回的相应值替换。

例:

   a = tf.constant([10, 20])
   b = tf.constant([1.0, 2.0])
   # 'fetches' can be a singleton
   v = session.run(a)
   # v is the numpy array [10, 20]
   # 'fetches' can be a list.
   v = session.run([a, b])
   # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
   # 1-D array [1.0, 2.0]
   # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
   MyData = collections.namedtuple('MyData', ['a', 'b'])
   v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
   # v is a dict with
   # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
   # 'b' (the numpy array [1.0, 2.0])
   # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
   # [10, 20].

可选的feed_dict参数允许调用者覆盖图中张量的值。feed_dict中的每个键都可以是以下类型之一:

  • 如果键是tf.Tensor,其值可以是Python标量、字符串、列表或numpy ndarray,可以转换为与该张量相同的dtype。此外,如果键是tf.compat.v1.placeholder,将检查值的形状是否与占位符兼容。
  • 如果键是tf.Tensorsparse,这个值应该是tf.SparseTensorValue。
  • 如果键是张量或稀疏张量的嵌套元组,则该值应该是嵌套元组,其结构与上面映射到其对应值的结构相同。

feed_dict中的每个值必须转换为对应键的dtype的numpy数组。可选参数options期望RunOptions原型。这些选项允许控制此特定步骤的行为(例如打开跟踪)。可选的run_metadata参数需要一个[RunMetadata]原型。在适当的时候,这个步骤的非张量输出将被收集到这里。例如,当用户在options中打开跟踪选项时,所分析的信息将被收集到这个参数中并传递回去。

参数:

  • fetches: 单个图元素、图元素列表或字典,其值是图元素或图元素列表(如上所述)
  • feed_dict: 将图形元素映射到值的字典(如上所述)。
  • options: [runo]协议缓冲区
  • run_metadata: 一个[RunMetadata]协议缓冲区

返回值:

如果fetches是单个图形元素,则使用单个值;如果fetches是列表,则使用值列表;如果fetches是字典,则使用与之相同的键的字典(如上所述)。未定义在调用中计算获取操作的顺序。

可能产生的异常:

  • RuntimeError: If this Session is in an invalid state (e.g. has been closed).
  • TypeError: If fetches or feed_dict keys are of an inappropriate type.
  • ValueError: If fetches or feed_dict keys are invalid or refer to a Tensor that doesn't exist.

例:

# tf.Session().as_default():创建一个默认会话

# 那么问题来了,会话和默认会话有什么区别呢?TensorFlow会自动生成一个默认的计算图,如果没有特殊指定,
# 运算会自动加入这个计算图中。TensorFlow中的会话也有类似的机制,但是TensorFlow不会自动生成默认的会
# 话,而是需要手动指定。

# tf.Session()创建一个会话,当上下文管理器退出时会话关闭和资源释放自动完成。

# tf.Session().as_default()创建一个默认会话,当上下文管理器退出时会话没有关闭,还可以通过调用会话
# 进行run()和eval()操作,代码示例如下:

import tensorflow as tf
a = tf.constant(1.0)
b = tf.constant(2.0)
with tf.Session() as sess:
   print(a.eval())
print(b.eval(session=sess))

Output:
-----------------------------------------------------------------------------------
Traceback (most recent call last):
  File "D:/tensorflow_learning/test.py", line 6, in <module>
    print(b.eval(session=sess))
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 711, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 5155, in _eval_using_default_session
1.0
    return session.run(tensors, feed_dict)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 887, in run
    run_metadata_ptr)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1033, in _run
    raise RuntimeError('Attempted to use a closed Session.')
RuntimeError: Attempted to use a closed Session.
解释:
在打印张量b的值时报错,tf.Session()运行完成后会默认关闭会话。
------------------------------------------------------------------------------------




# 将代码改为:
import tensorflow as tf
a = tf.constant(1.0)
b = tf.constant(2.0)
with tf.Session().as_default() as sess:
   print(a.eval())
print(b.eval(session=sess))

Output:
-----------------------------------------
1.0
2.0
解释:
tf.Session().as_default()默认不会关闭会话
-----------------------------------------




# 也可以显式的调用close()关闭会话
import tensorflow as tf
a = tf.constant(1.0)
b = tf.constant(2.0)
with tf.Session().as_default() as sess:
   print(a.eval())  
   sess.close()
print(b.eval(session=sess))

Output:
----------------------------------------------------------------------------------
1.0
Traceback (most recent call last):
  File "D:/tensorflow_learning/test.py", line 7, in <module>
    print(b.eval(session=sess))
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 711, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 5155, in _eval_using_default_session
    return session.run(tensors, feed_dict)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 887, in run
    run_metadata_ptr)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1033, in _run
    raise RuntimeError('Attempted to use a closed Session.')
RuntimeError: Attempted to use a closed Session.

解释:
tf.Session().as_default()需要手动调用类中的close()关闭会话。
-----------------------------------------------------------------------------------

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Wanderer001

ROIAlign原理

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值