np.argmax(input,axis)和tf.argmax(input,axis)分别是numpy和TensorFlow底下的求最大值索引的方法,用法基本一致,只有默认情况下有细微差别,以及传入的值略有不同,分别是array和tensor。
说白了,是不同模块下的相同方法。。只是不同模块下,数据类型不一致而已。。
一、np.argmax(input,axis)的使用
tf.argmax(input,axis),根据axis取值的不同返回每行或者每列(在axis上比较)最大值的索引。
1.数组长度一致时,2维数组:
test = np.array([ [1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
print(np.argmax(test))
np.argmax(test, 0)
np.argmax(test, 1)
#输出:
9
[3, 3, 1]
[2, 2, 0, 0]
axis = 0:列最大索引
axis=0时,比较每一列元素,记录每一列最大元素所在的索引,最后输出每一列最大元素所在的索引数组。
test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output : [3, 3, 1]
axis = 1:行最大索引
axis=1时,比较每一行元素,记录每一行最大元素所在的索引,最后返回每一行最大元素所在的索引数组。
test[0] = array([1, 2, 3]) #2
test[1] = array([2, 3, 4]) #2
test[2] = array([5, 4, 3]) #0
test[3] = array([8, 7, 2]) #0
2.数组长度一致时,n维数组:数组的shape很重要!
test = np.array([
[[19, 2, 3],
[2, 21, 2]],
[[5, 4, 3],
[1, 2, 3]],
[[5, 4, 6],
[1, 2, 3]],
[[15, 14, 13],
[11, 12, 3]]
])
# 本例中,
# test形状是4*2*3,这个特别重要,axis=0,就是4个同一位置的元素比较,axis=1就是2个元素比较,axis=2就是3个元素比较。
# 再举个例子,
# test形状是3*7*5*10,这个特别重要,axis=0,就是3个同一位置的元素比较,axis=1就是7个元素比较,axis=2就是5个元素比较,axis=3就是10个元素比较。
axis=None或省略
print(np.argmax(test, axis=None))
#输出:4
# axis=None和省略结果相同,直接当成一维数组来查,21最大,是第5个元素,从0开始,对应的下标是4。
axis=0:
print(np.argmax(test, 0))
#输出:
[[0 3 3]
[3 0 1]]
# axis = 0,其实是在第0维,也就是shape的第一个数4对应的那一维,比较4个元素的值。输出的是shape除了4之外的2*3的数组。
# 本例中,第一个元素0是4个元素19,5,5,15比较时max值19对应的索引,第二个元素3是4个元素2,4,4,14比较时max值14对应的索引……
axis = 1:
print(np.argmax(test, 1))
#输出:
[[0 1 0]
[0 0 0]
[0 0 0]
[0 0 0]]
# axis = 1,其实是在第1维,也就是shape的第二个数2对应的那一维,比较2个元素的值。输出的是shape除了2之外的4*3的数组。
# 本例中,第一个元素0是2个元素19,2比较时max值19对应的索引,第二个元素1是2个元素2,21比较时max值21对应的索引,第三个元素0是2个元素3,2比较时max值3对应的索引……
axis = 2:
print(np.argmax(test, 2))
#输出:
[[0 1]
[0 2]
[2 2]
[0 1]]
# axis = 2,其实是在第2维,也就是shape的第三个数3对应的那一维,比较3个元素的值。输出的是shape除了3之外的4*2的数组。
# 本例中,第一个元素0是3个元素19,2,3比较时max值19对应的索引,第二个元素1是3个元素2,21,2比较时max值21对应的索引,第三个元素0是3个元素5, 4, 3比较时max值5对应的索引,第四个元素2是3个元素1, 2, 3比较时max值3对应的索引……
3.数组长度不一致时:
axis最大值为数组维数-1,超过则报错。参考n维数组的例子,就是在每一个axis上比较的,很明显超过维度没有意义。
不一致时,axis=0的比较也就变成了每个数组的和的比较。【这个不理解,有问题?】
二、tf.argmax(input,axis)的使用
test = tf.Variable([
[1, 2, 3],
[2, 13, 4],
[5, 4, 3],
[1, 2, 7]])
print(tf.argmax(test)) # 这个与np.argmax不同,默认axis=None或省略,与axis=0的结果相同。
print(tf.argmax(test, 0)) # 与np.argmax相同,tensor形式
print(tf.argmax(test, 1)) # 与np.argmax相同,tensor形式
# 输出:
tf.Tensor([2 1 3], shape=(3,), dtype=int64)
tf.Tensor([2 1 3], shape=(3,), dtype=int64)
tf.Tensor([2 1 0 2], shape=(4,), dtype=int64)
参考:
https://blog.csdn.net/weixin_44810016/article/details/91492069
https://blog.csdn.net/u012300744/article/details/81240580
https://blog.csdn.net/qq575379110/article/details/70538051/
作者:西伯尔
出处:http://www.cnblogs.com/sybil-hxl/
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。