【Leetcode 堆、快速选择、Top-K问题 BFPRT】有序矩阵中第K小的元素(378)

题目

给定一个 n x n 矩阵,其中每行和每列元素均按升序排序,找到矩阵中第k小的元素。
请注意,它是排序后的第k小元素,而不是第k个元素。

示例:

matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8,

返回 13。

说明:
你可以假设 k 的值永远是有效的, 1 ≤ k ≤ n2 。

解答

这个问题和Leetcode 215笔记非常相似,可以用相同的几种思路解决掉。其中BFPRT时间复杂度O(N)

但这个题的输入是一个有序的矩阵,应该是有更好的办法吧!?找一圈没找到,有时间再来看。

思路:
1,全部收入列表,排序,取值。O(N·log(N))
2,维护一个大小为 k 的堆,元素大于等于堆顶负数入堆,堆顶就是第k小。O(N·log(k))
3,快速选择。最好O(N),最坏O(N^2)
4,BFPRT。O(N)

注:N表示元素个数,即n^2个

通过代码如下:

import random
from heapq import *

class Solution:
    # 排序
    # def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
    #     l = []
    #     for m in matrix:
    #         l.extend(m)
    #     return sorted(l)[k-1]

    # 快速选择
    # def kthSmallest(self, matrix, k):
    #     nums = []
    #     for m in matrix:
    #         nums.extend(m)

    #     def partition(left, right, base):
    #         temp = nums[base]
    #         nums[base], nums[right] = nums[right], nums[base]  # 基准和末尾元素互换

    #         max_index = left
    #         for i in range(left, right):  # 把所有小于基准的移到左边
    #             if nums[i] < temp:
    #                 nums[max_index], nums[i] = nums[i], nums[max_index]
    #                 max_index += 1

    #         nums[right], nums[max_index] = nums[max_index], nums[right]  # 基准归位
    #         return max_index

    #     def select(left, right, k_smallest):
    #         """在 nums[left, right] 找第k小的元素"""
    #         if left == right:  # 递归终止条件
    #             return nums[left]
    #         pivot_index = random.randint(left, right)  # 随机选择基准(比固定选第一个要好)
    #         base_index = partition(left, right, pivot_index)  # 选第一个(left)为基准,并归位。
    #         if base_index == k_smallest:  # 判断目前已归位的基准,是不是第k_smallest位
    #             return nums[k_smallest]
    #         elif k_smallest < base_index:  # go to 左半部分
    #             return select(left, base_index - 1, k_smallest)
    #         else:  # go to 右半部分
    #             return select(base_index + 1, right, k_smallest)

    #     return select(0, len(nums) - 1, k-1)  # 第k大,是第n-k小

    # 堆
    # def kthSmallest(self, matrix, k):
    #     nums = []
    #     for m in matrix:
    #         nums.extend(m)
    #     hq = []
    #     for x in nums:
    #         if len(hq) < k:
    #             heappush(hq, -x)
    #         elif -x >= hq[0]:
    #             heapreplace(hq, -x)
    #     return -heappop(hq)


# BFPRT解法
# Time: O(n), Space: O(n)
class Solution:
    def kthSmallest(self, matrix, k: int) -> int:
        nums = []
        for row in matrix:
            nums.extend(row)

        return self.BFPRT(nums, 0, len(nums)-1, k-1)

    def BFPRT(self, nums, left, right, K):
        """BFPRT算法"""
        if left == right:
            return nums[left]  # 不是排序,这要返回值

        base = self.medianOfMedians(nums, left, right)  # 找基准
        base_index = self.partition(nums, left, right, nums.index(base))  # 基准归位

        if base_index == K:
            return nums[base_index]
        elif base_index > K:
            return self.BFPRT(nums, left, base_index - 1, K)  # 递归左边。不是快排,这要返回值
        else:
            return self.BFPRT(nums, base_index + 1, right, K)  # 递归右边

    def medianOfMedians(self, nums, left, right):
        """找中位数基准"""
        length = right - left + 1
        offset = 0 if length % 5 == 0 else 1  # 最后不够5个的算一组
        groups = length // 5 + offset

        medians = []
        for i in range(groups):
            start = left + i * 5
            end = start + 4
            medians.append(self.get_median(nums[start: min(end, right) + 1]))
        return self.BFPRT(medians, 0, groups - 1, groups // 2)  # 这里递归BFPRT,保证得到的是“准确中位数”(30% ~ 70%); 而非“近似中位数” 防止时间复杂度退化

    def get_median(self, nums):
        """找5个数的中位数"""
        return sorted(nums)[len(nums) // 2]  # 常数级别

    def partition(self, nums, left, right, base):
        """常规partition"""
        if left >= right:
            return
        temp = nums[base]
        nums[base], nums[right] = nums[right], nums[base]

        max_index = left
        for i in range(left, right):
            if nums[i] <= temp:
                nums[i], nums[max_index] = nums[max_index], nums[i]
                max_index += 1
        nums[max_index], nums[right] = nums[right], nums[max_index]
        return max_index
posted @ 2019-12-14 16:00  961897  阅读(362)  评论(0编辑  收藏  举报