np.argmax(input,axis)和tf.argmax(input,axis)分别是numpy和TensorFlow底下的求最大值索引的方法,用法基本一致,只有默认情况下有细微差别,以及传入的值略有不同,分别是array和tensor。

说白了,是不同模块下的相同方法。。只是不同模块下,数据类型不一致而已。。

 

一、np.argmax(input,axis)的使用

tf.argmax(input,axis),根据axis取值的不同返回每行或者每列(axis上比较最大值的索引

 

1.数组长度一致时,2维数组:

test = np.array([ [1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])

print(np.argmax(test))     

np.argmax(test, 0)

np.argmax(test, 1)

#输出:

9

[3, 3, 1]

[2, 2, 0, 0]

axis = 0:列最大索引

axis=0时,比较每一列元素,记录每一列最大元素所在的索引,最后输出每一列最大元素所在的索引数组。

test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output   :       [3, 3, 1]

axis = 1:行最大索引

axis=1时,比较每一行元素,记录每一行最大元素所在的索引,最后返回每一行最大元素所在的索引数组。

test[0] = array([1, 2, 3])  #2
test[1] = array([2, 3, 4])  #2
test[2] = array([5, 4, 3])  #0
test[3] = array([8, 7, 2])  #0

 

2.数组长度一致时,n维数组:数组的shape很重要!

test = np.array([

                 [[19, 2, 3],

                  [2, 21, 2]],

                 [[5, 4, 3],

                  [1, 2, 3]],

                 [[5, 4, 6],

                  [1, 2, 3]],

                 [[15, 14, 13],

                  [11, 12, 3]]

                 ])

# 本例中,

# test形状是4*2*3,这个特别重要,axis=0,就是4个同一位置的元素比较,axis=1就是2个元素比较,axis=2就是3个元素比较。

# 再举个例子,

# test形状是3*7*5*10,这个特别重要,axis=0,就是3个同一位置的元素比较,axis=1就是7个元素比较,axis=2就是5个元素比较,axis=3就是10个元素比较。

 

axis=None或省略

print(np.argmax(test, axis=None)) 

#输出:4

# axis=None和省略结果相同,直接当成一维数组来查,21最大,是第5个元素,从0开始,对应的下标是4。

axis=0:

print(np.argmax(test, 0)) 

#输出:

[[0 3 3]

 [3 0 1]]

# axis = 0,其实是在第0维,也就是shape的第一个数4对应的那一维,比较4个元素的值。输出的是shape除了4之外的2*3的数组。

# 本例中,第一个元素0是4个元素19,5,5,15比较时max值19对应的索引,第二个元素3是4个元素2,4,4,14比较时max值14对应的索引……

axis = 1:

print(np.argmax(test, 1))

#输出:

[[0 1 0]

 [0 0 0]

 [0 0 0]

 [0 0 0]]

# axis = 1,其实是在第1维,也就是shape的第二个数2对应的那一维,比较2个元素的值。输出的是shape除了2之外的4*3的数组。

# 本例中,第一个元素0是2个元素19,2比较时max值19对应的索引,第二个元素1是2个元素2,21比较时max值21对应的索引,第三个元素0是2个元素3,2比较时max值3对应的索引……

axis = 2:

print(np.argmax(test, 2))

#输出:

[[0 1]

 [0 2]

 [2 2]

 [0 1]]

# axis = 2,其实是在第2维,也就是shape的第三个数3对应的那一维,比较3个元素的值。输出的是shape除了3之外的4*2的数组。

# 本例中,第一个元素0是3个元素19,2,3比较时max值19对应的索引,第二个元素1是3个元素2,21,2比较时max值21对应的索引,第三个元素0是3个元素5, 4, 3比较时max值5对应的索引,第四个元素2是3个元素1, 2, 3比较时max值3对应的索引……

 

3.数组长度不一致时:

axis最大值为数组维数-1,超过则报错。参考n维数组的例子,就是在每一个axis上比较的,很明显超过维度没有意义。

不一致时,axis=0的比较也就变成了每个数组的和的比较。【这个不理解,有问题?】

 

二、tf.argmax(input,axis)的使用

test = tf.Variable([

          [1, 2, 3],

                    [2, 13, 4],

                    [5, 4, 3],

                    [1, 2, 7]])

 

print(tf.argmax(test))  # 这个与np.argmax不同,默认axis=None或省略,与axis=0的结果相同。

print(tf.argmax(test, 0))  # 与np.argmax相同,tensor形式

print(tf.argmax(test, 1)) # 与np.argmax相同,tensor形式

# 输出:

tf.Tensor([2 1 3], shape=(3,), dtype=int64)

tf.Tensor([2 1 3], shape=(3,), dtype=int64)

tf.Tensor([2 1 0 2], shape=(4,), dtype=int64)

 

 

 

参考:

https://blog.csdn.net/weixin_44810016/article/details/91492069

https://blog.csdn.net/u012300744/article/details/81240580

https://blog.csdn.net/qq575379110/article/details/70538051/

 

posted on 2020-10-11 10:48  西伯尔  阅读(412)  评论(1编辑  收藏  举报