TensorFlow 是一个用于机器学习和深度学习的开源库。在 TensorFlow 1.x 版本中,会话(Session)用于执行 TensorFlow 计算图(Graph)。在 TensorFlow 2.x 版本中,即时执行(Eager Execution)成了默认选项,但了解会话的使用仍然很有用。
在 TensorFlow 1.x 中使用 Session
创建会话
创建一个 TensorFlow Session 的基础用法如下:
import tensorflow as tf
# 创建一个计算图
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# 初始化一个会话
sess = tf.Session()
# 执行计算图
print(sess.run(c))
Session.run() 详解
Session.run()
是执行 TensorFlow 计算图的主要方法。该方法接受多种类型的参数。
- fetches: 需要获取的 tensor 或者 Operation。可以是单个值,也可以是列表、字典等数据结构。
- feed_dict: 用于喂入数据的字典。
以下是一些例子:
# 单个 tensor
result = sess.run(c)
print(result)
# 多个 tensor
result = sess.run([a, b, c])
print(result)
# 使用 feed_dict 喂入数据
d = tf.placeholder(tf.float32)
e = a * d
result = sess.run(e, feed_dict={d: 2.0})
print(result)
Session 参数
在创建 tf.Session
的时候,你可以传入多个可选参数以控制会话的行为。
- target:如果你需要连接到远程的 TensorFlow 服务器或集群,可以设置这个参数。
- graph:用于设置计算图。默认是当前的默认图。
- config:通过 tf.ConfigProto 对象来设置底层的执行选项。
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
InteractiveSession
除了基础的 Session
,还有 InteractiveSession
,主要用于 Jupyter Notebook 或 Python shell 等交互式环境。
sess = tf.InteractiveSession()
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
print(c.eval())
sess.close()
使用 with
语句管理 Session
with tf.Session() as sess:
print(sess.run(c))
使用 Feed 和 Fetch
# 使用 Feed 传入数据
d = tf.placeholder(tf.float32)
e = a * d
with tf.Session() as sess:
print(sess.run(e, feed_dict={d: 2.0}))
# 使用 Fetch 获取多个数据
with tf.Session() as sess:
print(sess.run([a, b, c]))
在 TensorFlow 2.x 中使用 Session
在 2.x 版本中,你可以通过 tf.compat.v1
来使用 Session
。
import tensorflow as tf
# 禁用即时执行
tf.compat.v1.disable_eager_execution()
# 使用 tf.compat.v1 创建计算图和会话
a = tf.compat.v1.constant(5.0)
b = tf.compat.v1.constant(6.0)
c = a * b
sess = tf.compat.v1.Session()
print(sess.run(c))