tf.where(a,b,c)函数:
功能:当a输出结果为true时,tf.where(a,b,c)函数会选择b值输出。
当a输出结果为false时,tf.where(a,b,c)函数会选择c值输出。
例子:
import tensorflow as tf
v1=tf.constant([1.0,2.0,3.0,4.0])
v2=tf.constant([4.0,3.0,2.0,1.0])
with tf.Session() as sess:
init=tf.global_variables_initializer()
sess.run(init)
print(sess.run(tf.greater(v1,v2)))
print(sess.run(tf.where(tf.greater(v1,v2),v1,v2)))
结果:
[False False True True]
[ 4. 3. 3. 4.]