tf.nn.top_k

tf.nn.top_k(input, k, name=None)
返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引

sample

import tensorflow as tf
import numpy as np
 
input = tf.constant(np.random.rand(3,4))
k = 2
output = tf.nn.top_k(input, k)
with tf.Session() as sess:
    print(sess.run(input))
    print(sess.run(output))

output

[[0.02285388 0.18911185 0.82657187 0.8928524 ]
 [0.35208592 0.23944607 0.48238538 0.22885373]
 [0.15866412 0.14843213 0.70295158 0.30709085]]
TopKV2(values=array([[0.8928524 , 0.82657187],
       [0.48238538, 0.35208592],
       [0.70295158, 0.30709085]]), indices=array([[3, 2],
       [2, 0],
       [2, 3]], dtype=int32))

posted @ 2019-07-17 14:14  JohnRed  阅读(378)  评论(0编辑  收藏  举报