[Leetcode]215. Kth Largest Element in an Array
这是Leetcode第215题,求无序数组第K大的数。
求第K大/小的数也是一个经典的问题了,一般来说有两种方法:堆思想和快排思想。其时间复杂度分别达到\(O(NlogK)\)和\(O(N)\)。我们先分析这两种算法,然后分析一个优化算法。
堆
一般来说,求第K大数使用最小堆,求K小数使用最大堆。时间复杂度都是\(O(NlogK)\)。
其思想是:维护一个K大小的最小堆,对于数组中的每一个元素判断与堆顶的大小,若堆顶较大,则不管,否则,弹出堆顶,将当前值插入到堆中。每一次更新堆的时间是\(logK\),而一共有N个数,所以是O(NlogK)。
代码如下:
import heapq
class Solution(object):
def findKthLargest(self, nums, k):
min_heap = nums[:k]
heapq.heapify(min_heap) # create a min-heap whose size is k
for num in nums[k:]:
if num > min_heap [0]:
heapq.heapreplace(min_heap , num)
# or use:
# heapq.heappushpop(min_heap, num)
return min_heap [0]
快排分治
利用快速排序的思想,从数组S中随机找出一个元素X,把数组分为两部分Sa和Sb。Sa中的元素大于等于X,Sb中元素小于X。这时有两种情况:
- Sa中元素的个数小于k,则Sb中的第k-|Sa|个元素即为第k大数;
- Sa中元素的个数大于等于k,则返回Sa中的第k大数。时间复杂度近似为\(O(n)\)
第一趟快排没找到,时间复杂度为\(O(n)\),第二趟也没找到,时间复杂度为O(n/2),...,第k趟找到,时间复杂度为\(O(n/2k)\),所以总的时间复杂度为\(O(n(1+1/2+….+1/2k))=O(n)\)。但在最差情况下,时间复杂度则为\(O(N^2)\)。
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
lo,hi = 0,len(nums)-1
while lo<hi:
j = self.partition(nums,lo,hi)
if j==k-1:return nums[j]
elif j<k-1: lo = j+1
else: hi = j-1
return nums[k-1]
def partition(self,a,lo,hi):
i=lo+1;j=hi
v = a[lo]
while True:
while a[i]>=v and i<hi:
i+=1
while v>=a[j] and j>lo:
j-=1
if (i>=j):break
a[i],a[j] = a[j],a[i]
a[lo],a[j] = a[j],a[lo]
return j
BFPRT算法
在上一种方法之上,加上一个筛选划分元素的过程,就能把最坏时间复杂度降到\(O(n)\)。筛选的过程就是把所有的数等分成很多小段,然后求所有小段的中间值。构成一个由所有中间值组成的段,然后再取中间值,作为划分元素。即中间值的中间值作为划分元素。取中间值可以先任选一种排序方法排序之后选择,因为每一小段的长度很短,不是影响复杂度的主要因素;取中间值的中间值,利用递归的方法调用自身即可。
这样就可以把最坏时间复杂度降到\(O(n)\)了,复杂度证明比较繁琐。
BFPRT的本质是对划分值的选择进行优化,不再是随机选取划分值。
具体流程如下:
1、将数组分为5个一组,不足5个的自动成一组。划分组所用时间为O(1)。
2、将每个组进行组内排序,可用插入排序。因为排序只有5个数,时间复杂度可记为O(1),所有组都排序为O(N)。
3、得到每个组内的上中位数,然后将这些上中位数组成新的数组mediums。
4、求出mediums数组中的上中位数pvalue,不使用排序,用的是递归调用BFPRT的过程,求上中位数就是求mediums数组第mediums.size()/2小的数。
5、此时得到的pvalue就是选取的划分值,然后进行partition过程即可。
为什么要这样选取划分值,这是因为,假设数组长度为n,则mediums数组的长度为n/5,则得到的pvalue在medium中会有n/10的数比其小,而这n/10的数,在自己的5个数的小组中,又会有3个数比pvalue小,所以,至少有n/10*3即3n/10个数比pvalue小,至多有7n/10个数比pvalue大,可以确定的淘汰掉3n/10的数。这样划分比较均衡。
6、刚才拿到pvalue划分值之后,进行partition过程,会返回等于区的边界下标。
7、如果k在等于的范围内,则返回pvalue;k在等于区的左边,则递归调用左边小于区的部分;k在等于区的右边,则递归调用大于区的部分。
class Solution:
def findKthSmallest(self, nums, k) :
return self.bfprt(nums,0,len(nums)-1,k-1)
def bfprt(self,nums,lo,hi,k):
if lo == hi:
return nums[lo]
if hi-lo <= 5:
return sorted(nums[lo:hi+1])[k-lo]
pivot = self.medianOfMedians(nums,lo,hi)
ind = lo+nums[lo:hi+1].index(pivot)
range = self.partition(nums,lo,hi,ind)
if k>=range[0] and k<=range[1]:
return nums[k]
elif k < range[0]:
return self.bfprt(nums,lo,range[0]-1,k)
else:
return self.bfprt(nums,range[1]+1,hi,k)
def medianOfMedians(self,nums,lo,hi):
medians = []
new = nums[lo:hi+1]
for i in range(0,len(new),5):
cur = sorted(new[i:i+5])
mid = cur[len(cur)//2]
medians.append(mid)
return self.bfprt(medians,0,len(medians)-1,len(medians)//2)
def partition(self, a, lo, hi,ind):
a[ind],a[lo] = a[lo],a[ind]
lt,gt = lo,hi
i = lo+1
v = a[lo]
while i<=gt:
if a[i] < v:
a[i],a[lt] = a[lt],a[i]
i+=1
lt+=1
elif a[i]>v:
a[i],a[gt] = a[gt],a[i]
gt -= 1
else:
i+=1
return [lt,gt]
其中partition方法不妨看做快排三向切分中的partition。为了消除重复值的影响,我们返回的pivot的下界和上界,即第一个pivot的索引\(lt\)以及最后一个pivot的索引\(gt\)。在一次partition后,\(lt\)左侧均为小于pivot的值,\(gt\)右侧均为大于pivot的值。
扩展:
求一个无序数组的中位数。
实际上即求数组的topN问题,其中若数组长度Length为奇数,N为Length//2+1;若数组长度为偶数,N为Length//2+1,以及Length//2的数的均值。
注意,在这种问题中,使用堆求解,实际上就是\(NlogN\)的复杂度了,所以一般使用快排或者BFPRT算法。