378. Kth Smallest Element in a Sorted Matrix - cocoder39/coco39_LC GitHub Wiki

378. Kth Smallest Element in a Sorted Matrix

Notes 2024 — solution 1: a variant of merging k sorted lists, using a min-heap to pop the k smallest values.

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Return the k-th smallest value (1-indexed) of a row- and column-sorted matrix.

        Each row is treated as a sorted list and merged with a min-heap (k-way merge).
        Only the first min(k, m) rows need seeding: every value in row i is >= the
        i column-0 entries above it, so rows at index >= k cannot contribute a value
        that is strictly needed among the k smallest.
        """
        m = len(matrix)
        # Heap entries are (value, row, col). Each row has at most one live entry,
        # so (value, row) already makes every tuple distinct.
        min_heap = [(matrix[i][0], i, 0) for i in range(min(k, m))]
        heapq.heapify(min_heap)

        # Discard the k-1 smallest values, refilling from the same row each time.
        for _ in range(k - 1):
            _, row, col = heapq.heappop(min_heap)
            if col + 1 < len(matrix[row]):
                heapq.heappush(min_heap, (matrix[row][col + 1], row, col + 1))

        # The heap top is now the k-th smallest value.
        return min_heap[0][0]

Solution 2: binary search over the value range, counting the entries ≤ mid with the sorted-matrix staircase traversal (starting from the bottom-left corner).

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Return the k-th smallest value (1-indexed) of a row- and column-sorted matrix.

        Binary-searches the value range [matrix[0][0], matrix[-1][-1]] for the
        smallest value v with at least k entries <= v. That v always occurs in the
        matrix, and duplicates are handled naturally. Works for any m x n matrix
        whose rows and columns are sorted ascending, not only square ones.
        """
        rows, cols = len(matrix), len(matrix[0])
        low, high = matrix[0][0], matrix[rows - 1][cols - 1]

        while low < high:
            mid = low + (high - low) // 2
            if self.countLessEqual(matrix, mid) < k:
                # Fewer than k entries are <= mid, so the answer is strictly larger.
                low = mid + 1
            else:
                high = mid

        return low

    def countLessEqual(self, matrix: List[List[int]], target: int) -> int:
        """Return how many entries of the sorted matrix are <= target.

        Walks a staircase from the bottom-left corner in O(rows + cols): when the
        current cell is <= target, every cell above it in the same column is too.
        """
        rows = len(matrix)
        cols = len(matrix[0]) if matrix else 0
        count = 0
        row, col = rows - 1, 0

        while row >= 0 and col < cols:
            if matrix[row][col] <= target:
                count += row + 1
                col += 1
            else:
                row -= 1

        return count

================

Notes 2020

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """Pop k values off a min-heap that merges the rows of an n x n sorted matrix.

        The heap starts with the first entry of every row. Each pop is replaced by the
        next entry of the same row, so the k-th pop yields the k-th smallest value.
        """
        size = len(matrix)
        # Entries are (value, row, col); a row never has two live entries at once,
        # so comparisons are decided by (value, row) and col is just carried along.
        frontier = [(line[0], r, 0) for r, line in enumerate(matrix)]
        heapq.heapify(frontier)

        result = matrix[0][0]
        remaining = k
        while remaining > 0:
            result, r, c = heapq.heappop(frontier)
            remaining -= 1
            if c + 1 < size:
                heapq.heappush(frontier, (matrix[r][c + 1], r, c + 1))
        return result

Time complexity: O(n) to build the heap, then O(k log n) for the k pops and pushes, so O(n + k log n) overall.

=================================================

Solution 1: heap-based merge. Time is O(n + k log n) (equivalently O(max(n, k log n))), memory is O(n).

The idea is similar to an external merge sort. Each column is already a sorted chunk, so a priority queue seeded with the head of every column repeatedly yields the smallest remaining element.

public int kthSmallest(int[][] matrix, int k) {
        // Columns are sorted top-to-bottom, so merge them like k sorted lists:
        // seed the queue with row 0 of every column, then walk down the column
        // of whichever cell was just removed.
        int size = matrix.length;
        PriorityQueue<Tuple> frontier = new PriorityQueue<Tuple>();
        for (int c = 0; c < size; c++) {
            frontier.add(new Tuple(0, c, matrix[0][c]));
        }

        int answer = matrix[0][0];
        for (int popped = 0; popped < k; popped++) {
            Tuple smallest = frontier.remove();
            answer = smallest.val;
            int below = smallest.row + 1;
            if (below < size) {
                frontier.add(new Tuple(below, smallest.col, matrix[below][smallest.col]));
            }
        }
        // After k removals, answer holds the k-th smallest value.
        return answer;
    }
    
    // A matrix cell (row, col) together with its value; ordered by value only.
    class Tuple implements Comparable<Tuple> {
        int row;
        int col;
        int val;
        public Tuple(int row, int col, int val) {
            this.row = row;
            this.col = col;
            this.val = val;
        }

        @Override
        public int compareTo(Tuple other) {
            // Integer.compare instead of `this.val - other.val`: the subtraction
            // overflows when the values are far apart (e.g. Integer.MAX_VALUE vs -1)
            // and then reports the wrong ordering to the PriorityQueue.
            return Integer.compare(this.val, other.val);
        }
    }

Solution 2: binary search over values. Time is O(len * lg(matrix[len - 1][len - 1] - matrix[0][0])). For int values the largest possible range is Integer.MAX_VALUE - Integer.MIN_VALUE = 2^32 - 1, which does not fit in an int. The lg of the range is therefore at most 32, so the complexity is O(len).

public int kthSmallest(int[][] matrix, int k) {
        // Binary-search the value range for the smallest v with countLessOrEqual(v) >= k.
        // That v is the k-th smallest entry; duplicates are handled because we search
        // for the first value whose count reaches k.
        int len = matrix.length;
        int low = matrix[0][0];
        int high = matrix[len - 1][len - 1];
        while (low < high) {
            // Midpoint in long arithmetic: high - low can exceed Integer.MAX_VALUE when
            // the matrix spans large negative and positive values. The arithmetic shift
            // floors toward negative infinity, so low <= mid < high and the loop progresses.
            int mid = (int) (((long) low + high) >> 1);
            if (countLessOrEqual(matrix, mid) >= k) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return low;
    }
    
// O(len) time
    private int countLessOrEqual(int[][] matrix, int num) {
        // Staircase walk from the bottom-left corner, like searching a sorted matrix:
        // if matrix[r][c] <= num, the whole column prefix 0..r is <= num, so count it and
        // move right; otherwise move up.
        int n = matrix.length;
        int total = 0;
        int r = n - 1;
        int c = 0;
        while (c < n && r >= 0) {
            if (matrix[r][c] <= num) {
                total += r + 1;
                c++;
            } else {
                r--;
            }
        }
        return total;
    }

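Another heap variant (C++) that grows the frontier lazily. Each iteration pops one cell and pushes at most two: the cell below it, plus the cell to its right when it is in row 0. The heap therefore gains at most one cell per iteration and never holds more than k cells. Time is O(k log k), memory is O(k).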

int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        // Min-heap of (row, col) cells, ordered by the matrix value each cell points to.
        auto greaterCell = [&matrix](pair<int, int> a, pair<int, int> b) {
            return matrix[a.first][a.second] > matrix[b.first][b.second];
        };
        priority_queue<pair<int, int>, vector<pair<int, int>>, decltype(greaterCell)> frontier(greaterCell);

        frontier.push(make_pair(0, 0));
        // Discard the k-1 smallest cells, expanding the frontier as we go.
        for (; k > 1 && !frontier.empty(); --k) {
            pair<int, int> cell = frontier.top();
            frontier.pop();
            int r = cell.first;
            int c = cell.second;
            // Always step down; step right only from row 0 so each cell is pushed once.
            if (r + 1 < n) {
                frontier.push(make_pair(r + 1, c));
            }
            if (r == 0 && c + 1 < n) {
                frontier.push(make_pair(0, c + 1));
            }
        }
        pair<int, int> kth = frontier.top();
        return matrix[kth.first][kth.second];
    }
⚠️ **GitHub.com Fallback** ⚠️