Session()方法
首先我们需要创建一个Session对象.在不传参数的情况下,该Session的构造器将启动默认的图.之后我们可以通过Session对象的run(op)
来执行我们想要的操作。tensorflow的内核使用更加高效的C++作为后台,以支撑它的密集计算。tensorflow把前台(即python程序)与后台程序之间的连接称为"会话(Session)"。
Session
作为会话,主要功能是指定操作对象的执行环境,Session
类构造函数有3个可选参数。
target
(可选):指定连接的执行引擎,多用于分布式场景。graph
(可选):指定要在Session对象中参与计算的图(graph)。config
(可选):辅助配置Session对象所需的参数(限制CPU或GPU使用数目,设置优化参数以及设置日志选项等)。-
tf.Session.run(fetches,feed-dict=Noe,options=Node,run_metadata=None) 运行fetches中的操作节点并求其 tf.Session.close() 关闭会话 tf.Session.graph 返回加载该会话的图() tf.Session.as_default() 设置该对象为默认会话,并返回一个上下文管理器
run()方法
Session对象创建完毕,便可以使用它最重要的方法run()来启动所需要的数据流图进行计算。
run()方法有4个参数:
-
run(
-
fetches,
-
feed_dict=
None
-
options=
None,
-
run_metadata=
None
-
)
(1).fetches参数
- '取得之物',表示数据流图中能接收的任意数据流图元素,各类Op/Tensor对象。Op,run()将返回None;Tensor,rnu()将返回Numpy数组。
-
import tensorflow as tf
-
from collections import namedtuple
-
-
a = tf.constant([
10,
20])
-
b = tf.constant([
1.
0,
2.
0])
-
session = tf.Session()
-
-
v1 = session.run(a) #fetches参数为单个张量值,返回值为Numpy数组
-
print(v
1)
-
v2 = session.run([a, b]) #fetches参数为python类表,包括两个numpy的
1维矩阵
-
print(v
2)
-
v3 = session.run(tf.global_variables_initializer()) #fetches 为Op类型
-
print(v
3)
-
session.close()
-
[
10
20]
-
[
array([
10,
20],
dtype=int32),
array([
1
.,
2
.],
dtype=float32)]
-
None
(2). feed_dict参数
- 可选项,给数据流图提供运行时数据。
feed_dict
的数据结构为python中的字典,其元素为各种键值对。"key"为各种Tensor对象的句柄;"value"很广泛,但必须和“键”的类型相匹配,或能转换为同一类型。
-
import tensorflow as tf
-
-
a = tf.add(
1,
2)
-
b = tf.multiply(a,
2)
-
session = tf.Session()
-
v1 = session.run(b)
-
print(v
1)
-
-
replace_dict = {a:
20}
-
v2 = session.run(b, feed_dict = replace_dict)
-
print(v
2)
-
6
-
40