numpy中的argmax、argmin、argwhere、argsort、argpartition函数

楔子

numpy中有几个以arg开头的函数,非常的神奇,因为它们返回的不是元素、而是元素的索引,我们来看一下用法,这里只以一维数组为例。

np.argmax

首先np.max是获取最大元素,那么np.argmax是做什么的呢?

import numpy as np

arr = np.array([3, 22, 4, 11, 2, 44, 9])
print(np.max(arr))  # 44
print(np.argmax(arr))  # 5

我们看到np.max是获取数组中最大的元素,np.argmax是获取数组中最大元素对应的索引。

同理还有np.argmin,np.min是获取数组中最小的元素,显然是2;np.argmin是获取数组中最小元素对应的索引,显然是4。

import numpy as np

arr = np.array([3, 22, 4, 11, 2, 44, 9])
print(np.min(arr))  # 2
print(np.argmin(arr))  # 4

np.argwhere

np.where我们算是经常使用了,先来复习一下它的用法吧。

import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# 如果元素大于4, 那么减去10; 否则扩大十倍
print(np.where(arr > 4, arr - 10, arr * 10))  # [10 20 30 40 -5 -4 -3 -2]

# 如果元素大于4, 那么保持不变, 否则变成4
print(np.where(arr > 4, arr, 4))  # [4 4 4 4 5 6 7 8]

和np.where作用类似的还有一个np.clip,来看一下。

import numpy as np

arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# 小于2的换成2, 大于6的换成6, 一般在设置上下限的时候非常有用
print(np.clip(arr, 2, 6))  # [2 2 3 4 5 6 6 6]

那么np.where是做啥的呢?首先这个函数只接受一个参数,找出满足条件的元素对应的索引。

import numpy as np

arr = np.array([3, 4, 5, 6, 7])
print(np.argwhere(arr % 2 != 0))
"""
[[0]
 [2]
 [4]]
"""
print(np.argwhere(arr % 2 != 0).flatten())  # [0 2 4]

显然元素3、5、7在%2之后不为0,所以会筛选出它们的索引,因此是[0 2 4]。只不过默认不是一个一维数组,我们需要再调用一下flatten,将其扁平化。

np.argsort

np.sort是用来排序的,类似于Python的内置函数sorted。

import numpy as np

arr = np.array([4, 2, 3, 6, 5, 1])
print(np.sort(arr))  # [1 2 3 4 5 6] 

sort很容易,再来看看argsort。

import numpy as np

arr = np.array([4, 2, 3, 6, 5, 1])
print(np.sort(arr))  # [1 2 3 4 5 6]
print(np.argsort(arr))  # [5 1 2 0 4 3]

sort是将从小到大排序之后返回,argsort是返回从小到大排序之后元素对应的索引。比如:第一个元素是5,表示原来数组中索引为5的元素在排序之后应该排在第一个位置上。

因此,通过argsort我们可以选出topN的元素。

import numpy as np

arr = np.array([4, 2, 3, 6, 5, 1, 8, 9, 7])
print(arr[np.argsort(arr)[-3:]])  # [7 8 9]

# 当然sort本身也是可以的
print(np.sort(arr)[-3:])  # [7 8 9]

下面看一个问题,如果我想查看数组中每一个元素在排完序之后对应的索引该怎么办呢?

以数组:[3 2 1 4]为例,在排完序之后结果显然是[1 2 3 4],那么原来的元素3应该在索引为2的位置上、元素2在索引为1的位置上、元素1在索引为0的位置上、元素4在索引为3的位置上,所以我们希望得到一个数组[2 1 0 3],那么要怎么做?

import numpy as np

arr = np.array([88, 79, 86, 97, 89, 95, 84])

# 调用一次argsort显然是不够的, 它表示排完序之后原来的元素对应的索引
print(np.argsort(arr))  # [1 6 2 0 4 5 3]

# 如果我们连续调用两次argsort的话, 另外np.argsort(arr) <==> arr.argsort()
print(arr.argsort().argsort())  # [3 0 2 6 4 5 1]

# 此时就大功告成了
# 数组[3 0 2 6 4 5 1]表示:
#   arr中第一个元素88在排完序之后应该处于索引为3的位置
#   79在排完序之后应该处于索引为0的位置
#   ...

以88为例,显然它在排序之后索引为3,所以对arr.argsort()得到的数组再进行一次argsort即可得到我们想要的结果。这个可能有点绕,使用言语表达起来实在是不太容易,可以自己看着图尝试一下。

np.argpartition

argpartition类似于argwhere,但它只是局部排序,举例说明:

import numpy as np

arr = np.array([66, 15, 27, 33, 19, 13, 10])

"""
np.partition(arr, n)
找出arr中第n + 1小的元素(将arr排序之后索引n的元素), 然后返回一个新数组
并将原来数组中第n + 1小的元素放在新数组索引为n的地方, 保证左边的元素比它小, 右边的元素比它大
"""
print(np.partition(arr, 3))  # [15 13 10 19 27 33 66]
# 第4小的元素(排完序之后索引为3)显然是19, 那么将19放在索引为3的位置, 然后左边的元素比它小, 右边的元素比它大
# 至于两边的顺序则没有要求

# 虽然我们可以使用sort, 但是sort是全局排序
# 如果数组非常大, 我们只希望选择最小的10个元素, 直接通过np.partition(arr, 9)即可
# 然后如果排序的话, 只对这选出来的10个元素排序即可, 而无需对整个大数组进行排序

# 同理还可以从后往前找, 比如:
# np.partition(arr, -2)表示找到第2大的元素(将arr排序之后索引-2的元素), 放在数组索引为-2的地方
# 然后左边元素比它小, 右边元素比它大
print(np.partition(arr, -2))  # [13 10 27 15 19 33 66]
# 第2大的元素显然是33, 那么排在索引为-2的位置, 左边元素比它小, 右边元素比它大


# 然后argpartition不用想, 肯定是获取排序之后的索引
print(np.argpartition(arr, 3))  # [1 5 6 4 2 3 0]
print(np.argpartition(arr, -2))  # [5 6 2 1 4 3 0]
posted @ 2020-09-20 23:01  古明地盆  阅读(4536)  评论(2编辑  收藏  举报