tf.nn.in_top_k(predictions,targets,k,name=None)tf.nn.in\_top\_k(predictions, targets, k, name=None)tf.nn.in_top_k(predictions,targets,k,name=None)
predictionspredictionspredictions:你的预测结果(一般也就是你的网络输出值)大小是预测样本的数量乘以输出的维度。
targettargettarget:实际样本类别的标签,大小是样本数量的个数。
kkk:每个样本中前KKK个最大的数里面(序号)是否包含对应targettargettarget中的值。
import tensorflow as tf
X=tf.Variable([[0.4,0.2,0.3,0.1],[0.1,0.1,0.2,0.6],[0.7,0.1,0.1,0.1]])
Y=tf.Variable([2,1,1])
k=tf.placeholder(tf.int32,shape=None)
result=tf.nn.in_top_k(X,Y,k)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print(sess.run(X))
print(sess.run(Y))
print(sess.run(result,feed_dict={k:1}))
print(sess.run(result,feed_dict={k:2}))
#结果为:
#[[0.4 0.2 0.3 0.1]
[0.1 0.1 0.2 0.6]
[0.7 0.1 0.1 0.1]]
#[2 1 1]
#[False False False]
#[ True False True]
分析一下结果
当k=1k=1k=1时,
XXX中[0.40.20.30.1][0.4 0.2 0.3 0.1][0.40.20.30.1]最大元素为0.40.40.4,索引为000,而BBB是222,不包含BBB,故FalseFalseFalse
[0.10.10.20.6][0.1 0.1 0.2 0.6][0.10.10.20.6]最大元素为0.60.60.6,索引为333,B是111,不包含BBB,故FalseFalseFalse
[0.70.10.10.1][0.7 0.1 0.1 0.1][0.70.10.10.1]最大元素为0.70.70.7,索引为000,B是111,不包含BBB,故FalseFalseFalse
当k=2k=2k=2时,
XXX中[0.40.20.30.1][0.4 0.2 0.3 0.1][0.40.20.30.1]最大的两个元素为0.4、0.30.4、0.30.4、0.3,索引为0、20、20、2,BBB为222,故TrueTrueTrue,
[0.10.10.20.6][0.1 0.1 0.2 0.6][0.10.10.20.6]最大两个元素为0.6、0.20.6、0.20.6、0.2,索引为3、23、23、2,BBB是111,不包含BBB,故FalseFalseFalse
[0.70.10.10.1][0.7 0.1 0.1 0.1][0.70.10.10.1]最大元素为0.7,0.10.7,0.10.7,0.1,索引为0、10、10、1,BBB是111,包含BBB,故TrueTrueTrue