在进行语义分割的二分类中,需要将预测值大于和小于0.5的logits分别标记为True和False。使用tf.equal(label, 0)只会判断改值是否为1。
使用tf.where(input, a, b)实现这个功能。
其中input是tensor+判断条件,判断得到True和False的一个sensor,它和a、b尺寸一致。
函数作用是将a中对应input中true的位置的元素值不变,其余元素进行替换,替换成b中对应位置的元素值。
tf.ones_like(label) 和 tf.zeros_like(label) 两个函数生成和label形状相同的两个纯1和纯0的tensor,具体实现代码:
t1 = tf.constant([-0.1, 0.3, -0.49, -0.02])
ones = tf.ones_like(t1)
zeros = tf.zeros_like(t1)
t2 = tf.where(t1>0, ones, zeros)
equal = tf.cast(t2, tf.bool)
cast_s = tf.cast(sign, tf.bool)
with tf.Session() as sess:
print(sess.run(t2))
print(sess.run(equal))
Output:
[0. 1. 0. 0.]
[False True False False]