tensorflow 里的tf.select(condition,a,b)解释及应用
首先参考tensorflow里的官网解释:http://www.tensorfly.cn/tfdoc/api_docs/python/control_flow_ops.html#less 里面是全英文比较费解。那么下面的解释对你应该有帮助:
condition:一个张量tensor,类型为bool
a :一个张量tensor,shape与condition一致,类型一般为
float32
, float64
, int32
, int64
.
b :一个张量tensor,类型和shape与a一致。
举例:
import tensorflow as tf
sess=tf.Session()
condition=[[True,False],[True,False]]
a=[[1,2],[3,4]]
b=[[5,6],[7,8]]
c=tf.select(condition,a,b)
print(sess.run(c))
输出:
[[1,6],[3,8]]
如果把condition改成[[True,
True],[True,False]]
输出变为:
[[1,2],[3,8]]
解释:a里对应condition中为True的位置在返回值中继续输出,b里对应condition中为False的位置在返回值中继续输出。
如果把condition改成[[True,True],[True,True]]
输出为a:
[[1,2],[3,4]]
[[5,6],[7,8]]