
[LeetCode] 973. K Closest Points to Origin_Medium tag: Sort, heap, quickSort

2021-07-29 07:10  Johnson_强生仔仔

Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).

The distance between two points on the X-Y plane is the Euclidean distance, i.e. $\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$.
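All three solutions in this post rank points by x² + y² rather than by the true distance. Because the square root is strictly increasing on non-negative numbers, comparing squared distances gives the same order while staying in integer arithmetic. A small sketch of that idea (the helper name `squared_distance` is hypothetical, not from the original post):

```python
import math
from typing import List

def squared_distance(point: List[int]) -> int:
    """x^2 + y^2: squared Euclidean distance from (0, 0), no sqrt or floats."""
    x, y = point
    return x * x + y * y

# Points from Example 1: (1, 3) vs (-2, 2)
a, b = [1, 3], [-2, 2]
print(squared_distance(a), squared_distance(b))  # 10 8
# Same verdict as comparing the true distances sqrt(10) and sqrt(8).
print((squared_distance(b) < squared_distance(a)) == (math.hypot(*b) < math.hypot(*a)))  # True
```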

You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).

 

Example 1:

Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].

Example 2:

Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.
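Example 2 gives no distances, so here they are as squared values: 18, 26 and 20, which makes (3, 3) and (-2, 4) the two closest. Since any order is accepted, a checker should compare answers as unordered collections; `same_answer` below is a hypothetical helper for that, not part of the problem.

```python
from typing import List

points = [[3, 3], [5, -1], [-2, 4]]
# Squared distance of each point from the origin.
squared = {tuple(p): p[0] ** 2 + p[1] ** 2 for p in points}
print(squared)  # {(3, 3): 18, (5, -1): 26, (-2, 4): 20}

def same_answer(a: List[List[int]], b: List[List[int]]) -> bool:
    """Order-insensitive comparison: the problem accepts any order."""
    return sorted(map(tuple, a)) == sorted(map(tuple, b))

print(same_answer([[3, 3], [-2, 4]], [[-2, 4], [3, 3]]))  # True
```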

Code:

1. Sort  T: O(n * lg n)

from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Sort by squared distance (no sqrt needed) and keep the first k.
        return sorted(points, key=lambda p: p[0] * p[0] + p[1] * p[1])[:k]

 

2. Use heap   T: O(n * lg k)

import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        return heapq.nsmallest(k, points, key=lambda p: p[0] * p[0] + p[1] * p[1])
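`heapq.nsmallest` hides the mechanism behind the O(n lg k) bound. Below is a sketch of the same idea written out by hand: keep a heap of at most k points whose root is the farthest one kept so far, and replace the root whenever a closer point arrives. Python's `heapq` is a min-heap, so distances are negated to get max-heap behaviour; the function name `k_closest_heap` is hypothetical, and it assumes k >= 1 as in the problem.

```python
import heapq
from typing import List, Tuple

def k_closest_heap(points: List[List[int]], k: int) -> List[List[int]]:
    # Entries are (-squared_distance, index); storing the negated distance
    # makes heap[0] the FARTHEST of the k points kept so far.
    heap: List[Tuple[int, int]] = []
    for i, (x, y) in enumerate(points):
        d = x * x + y * y
        if len(heap) < k:
            heapq.heappush(heap, (-d, i))      # O(lg k)
        elif -d > heap[0][0]:
            # Closer than the current farthest: pop it and push this one.
            heapq.heapreplace(heap, (-d, i))   # O(lg k)
    return [points[i] for _, i in heap]

print(sorted(k_closest_heap([[3, 3], [5, -1], [-2, 4]], 2)))  # [[-2, 4], [3, 3]]
```

Storing the index instead of the point itself keeps ties on distance from ever comparing the point lists.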

 

3. Use quick select/quick sort   T: average O(n), worst case O(n ^ 2)

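    3.1. Use a partition function that moves the pivot to its final sorted position, index, so that every point in [l, index - 1] is strictly closer to the origin than the pivot and every point in [index + 1, r] is at least as far.

    3.2. Repeat the partition on whichever side contains position k (iteratively, by moving l or r), until the pivot lands exactly at index k or the range shrinks to a single point; then points[:k] are the k closest points.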

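from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        n = len(points)
        l, r = 0, n - 1
        # Narrow [l, r] toward position k: everything left of l is no farther
        # than anything in [l, r], and everything right of r is no closer.
        while l < r:
            mid = self.partition(points, l, r)
            if mid == k:
                break
            if mid < k:
                l = mid + 1   # pivot and everything left of it are among the k closest
            else:
                r = mid - 1   # the k closest all lie left of the pivot
        return points[:k]

    def partition(self, points, l, r):
        # Lomuto partition with points[r] as pivot: afterwards points[l:index]
        # are strictly closer than the pivot, which ends up at index.
        pivot = points[r]
        index = l
        for i in range(l, r):
            if self.compare(points[i], pivot) < 0:
                points[i], points[index] = points[index], points[i]
                index += 1
        points[index], points[r] = points[r], points[index]
        return index

    def compare(self, p1, p2):
        # Negative if p1 is closer to the origin than p2 (squared distances).
        return (p1[0] ** 2 + p1[1] ** 2) - (p2[0] ** 2 + p2[1] ** 2)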
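The partition above always takes points[r] as the pivot, so input that is already sorted by distance drives it into the O(n ^ 2) worst case. A common mitigation, not used in the original post, is to swap a randomly chosen element into position r before partitioning, which makes the expected running time independent of the input order (many equal distances can still slow it down, because equal points all go to the right side). A sketch with the same loop structure; the names `dist2`, `partition` and `k_closest_quickselect` are hypothetical:

```python
import random
from typing import List

def dist2(p: List[int]) -> int:
    """Squared distance from the origin."""
    return p[0] * p[0] + p[1] * p[1]

def partition(points: List[List[int]], l: int, r: int) -> int:
    # Move a random element to r so it becomes the pivot.
    j = random.randint(l, r)
    points[j], points[r] = points[r], points[j]
    pivot = dist2(points[r])
    index = l
    for i in range(l, r):
        if dist2(points[i]) < pivot:
            points[i], points[index] = points[index], points[i]
            index += 1
    points[index], points[r] = points[r], points[index]
    return index

def k_closest_quickselect(points: List[List[int]], k: int) -> List[List[int]]:
    # Same narrowing loop as above; note that it reorders `points` in place.
    l, r = 0, len(points) - 1
    while l < r:
        mid = partition(points, l, r)
        if mid == k:
            break
        if mid < k:
            l = mid + 1
        else:
            r = mid - 1
    return points[:k]

print(sorted(k_closest_quickselect([[3, 3], [5, -1], [-2, 4]], 2)))  # [[-2, 4], [3, 3]]
```

Pass a copy of the list if the caller still needs the original order.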