tf.argmax()函数作用

tf.argmax()函数原型:

def argmax(input,
           axis=None,
           name=None,
           dimension=None,
           output_type=dtypes.int64)

作用是返回每列/行的最大值的索引。

input是一个张量,

axis是0或1,0返回各列最大值索引,1返回各行最大值索引。

其他3个参数不常用,常用写法是 a = tf.argmax(tensor, 1)。

 

import tensorflow as tf
sess = tf.InteractiveSession()

a = tf.constant([[12, 3, 9],
                 [3, 6, 13]]) 

b_1 = tf.argmax(a, 0)   # 返回ndarry,元素是每列的最大值索引
b_2 = tf.argmax(a, 1)

print(b_1)   # >>array([0, 1, 1], dtype=int64)
print(b_2)   # >>array([0, 2], dtype=int64)

 

posted @ 2020-02-23 20:29  panday  阅读(1084)  评论(0编辑  收藏  举报