Tensorflow小技巧整理:tf.cond()的小应用
tf.cond() 方法
tf.cond() 的作用类似于平常所使用的 if…else… 语句,但是在 tensorflow 中,所有节点是用图来保存的。而在图中传输的这些 tensor 数据流在我们使用 Session().run() 之前又是无法确定其数值的,所以这个时候传统的条件判断语句就无法使用。
比如我们想判断 a 和 b 是否相等:
a = tf.constant(3)
b = tf.constant(3)
# a 和 b 直接打印都是 <tf.Tensor '...' shape=() dtype=int32>
如果我们直接使用 ’ a == b ’ 判断,得到的是 ‘False’,而如果使用 ‘tf.equal(a,b)’ 来判断,返回的是一个新的张量:
is_equal = tf.equal(a,b)
# is_equal 是 <tf.Tensor '...' shape=() dtype=bool>
# 如果这个时候我们使用