LC 0973 [M] K Closest Points to Origin - ALawliet/algorithms GitHub Wiki

the classic top-K elements problem

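  • use a max heap of size K; Python's heaps (heapq, PriorityQueue) are min-heaps, so push -priority to get descending order
max heap: O(n log k) time, O(k) space — pop the largest whenever the heap grows past K, so the K smallest remain
sorting (e.g. quicksort): O(n log n)
quickselect: O(n) average case, O(n^2) worst case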

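because we measure distance to the origin and only use it as a priority (i.e. for comparisons), we can simplify the formula

sqrt( (x2-x1)^2 + (y2-y1)^2 ) 
=> sqrt( (x2-0)^2 + (y2-0)^2 ) # to the origin: x1 = y1 = 0
=> x2^2 + y2^2 # don't need sqrt: it is monotonic, so the ordering of distances is unchanged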

exponent: x ** 2 or pow(x,2)


import heapq
from queue import PriorityQueue
from typing import List

def distance(x, y):
    """Return the squared Euclidean distance of (x, y) from the origin.

    The square root is omitted: it is monotonic, so squared distances rank
    points in the same order, and integer inputs stay exact integers.
    """
    return x**2 + y**2


class Solution:
    def kClosest(self, points: List[List[int]], K: int) -> List[List[int]]:
        """Return the K points closest to the origin, in no particular order.

        Keeps a bounded max-heap of at most K entries. ``heapq`` is a
        min-heap, so each entry stores the *negated* distance; the root is
        then always the farthest point kept so far. Pushing every point and
        popping the root whenever the heap grows past K leaves the K
        closest. O(n log K) time, O(K) extra space.

        Args:
            points: ``[x, y]`` pairs. The list is not modified.
            K: number of points to return; assumed non-negative.

        Returns:
            The K closest points (all of them if K >= len(points)).
        """
        heap = []  # entries are (-squared_distance, point)
        for point in points:
            x, y = point
            heapq.heappush(heap, (-distance(x, y), point))
            if len(heap) > K:
                # Evict the farthest point kept so far.
                heapq.heappop(heap)
        return [point for _, point in heap]

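if we partition around the target index K, then once a pivot lands at index K, everything to its left (points[:K]) is no farther from the origin than the pivot, so points[:K] are the K closest — partitioned, not sorted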

quickselect looks like a binary search: after each partition it keeps only the side that contains index K

def distance(point):
    """Return the squared Euclidean distance of ``point`` from the origin.

    ``point`` must unpack into exactly two coordinates ``(x, y)``. The square
    root is skipped because only the ordering of distances matters.
    """
    x, y = point
    return x**2 + y**2


class Solution:
    def kClosest(self, points, K):
        """Return the K points closest to the origin, in no particular order.

        Uses in-place quickselect: once a partition places its pivot at
        index K, every point in ``points[:K]`` is no farther than the pivot
        or anything after it. NOTE: reorders ``points`` in place.

        Args:
            points: mutable list of ``[x, y]`` pairs.
            K: number of points to return; assumed non-negative.

        Returns:
            ``points[:K]`` after partitioning (all points if K >= len(points)).
        """
        self.quickselect(points, 0, len(points) - 1, K)
        return points[:K]

    def quicksort(self, points, l, r, x):
        """Recursive form of ``quickselect`` (an alternative, not a full sort).

        Recurses only into the side of the last pivot that contains index
        ``x``, stopping as soon as a pivot lands exactly at ``x``.
        """
        if l <= r:
            p = self.partition(points, l, r)
            if p == x:
                return  # points[:x] now hold the x closest (unsorted)
            elif p < x:
                self.quicksort(points, p + 1, r, x)
            else:
                self.quicksort(points, l, p - 1, x)

    def quickselect(self, points, l, r, x):
        """Partition ``points[l:r+1]`` in place until index ``x`` holds a pivot.

        Like a binary search over positions: after each partition only the
        side containing ``x`` is kept. If ``x`` lies outside ``[l, r]`` (e.g.
        ``x == len(points)``), the loop simply runs until the range empties.
        Average O(n); O(n^2) worst case, because the pivot is always the last
        element of the range (e.g. on already-sorted input).
        """
        while l <= r:
            p = self.partition(points, l, r)
            if p == x:
                return  # points[:x] now hold the x closest (unsorted)
            if p < x:
                l = p + 1
            else:
                r = p - 1

    def partition(self, A, l, r):
        """Lomuto partition of ``A[l:r+1]`` around the pivot ``A[r]``.

        Moves every point whose distance is <= the pivot's to the front of
        the range, then places the pivot right after them.

        Returns:
            The pivot's final index ``i``: ``A[l:i]`` are no farther than
            ``A[i]``, and ``A[i+1:r+1]`` are strictly farther.
        """
        pivot_dist = distance(A[r])  # hoisted: the pivot stays at A[r] until the final swap
        i = l
        for j in range(l, r):
            if distance(A[j]) <= pivot_dist:
                A[i], A[j] = A[j], A[i]  # Python lists have no .swap() method
                i += 1
        A[i], A[r] = A[r], A[i]
        return i