tf.nn.in_top_k(predictions, targets, k, name=None)
函数意义:判断每一个标号是不是矩阵每一行数值排序前 k的数
输入参数:
1) predictions: 表示输入矩阵
2)targets: 表示输入矩阵的每一行的列编号 从0 开始
3)k;表示每行值排序后 前 K的数
例程如下:
def Test_Top():
input=tf.constant(np.random.rand(3,4),tf.float32)
k=2
#第一行第2个数,第二行第3个数,第三行第2个数是不是最大的前2个数
output=tf.nn.in_top_k(input,[2,3,2],k)
with tf.Session() as sess:
print(sess.run(input))
print('------')
print(sess.run(output))