315. Count of Smaller Numbers After Self - cocoder39/coco39_LC GitHub Wiki


Update on 1/23/2021

How to tackle this kind of problem?

There are usually two patterns for divide and conquer:

Let f(i, j) denote the solution for the subarray nums[i:j] (both ends inclusive).

  1. f(i, j) = f(i+1, j) + C, where C denotes the subproblem "count of smaller numbers after nums[i]", i.e., how many numbers in nums[i+1:j] are smaller than nums[i]
  • Linear scan takes O(N) time per element, so overall time complexity is O(N^2)
  • Binary search takes O(log N) time to find the count in a sorted array. However, to keep the array sorted we need to insert each element as we move, which takes O(N), so overall time complexity is still O(N^2)
  • Balanced BST and BIT come into play because they support both search and insert in O(log N), so overall time complexity drops to O(N log N)
  2. f(i, j) = f(i, m) + f(m+1, j) + C, where C denotes the subproblem "count of smaller numbers in nums[m+1:j] for each number in nums[i:m]"
  • Linearly scanning both subarrays takes O(N^2) time to solve C
  • If both subarrays are sorted, C can be solved in O(N) with two pointers, so the recurrence gives O(N log N) overall
    • problem structure f(i, j) = f(i, m) + f(m+1, j) + C
    • maintain sorted subarrays
    • these two hints suggest merge sort
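
As a baseline for the linear-scan variant of the first pattern, here is a minimal brute-force sketch (the function name `count_smaller_brute_force` is hypothetical, not from the original notes). It walks the array from right to left and, for each element, linearly scans the numbers after it, which gives the O(N^2) bound mentioned above.

```python
from typing import List


def count_smaller_brute_force(nums: List[int]) -> List[int]:
    """O(N^2) baseline: for each nums[i], linearly count smaller numbers after it."""
    n = len(nums)
    counts = [0] * n
    # Right-to-left order mirrors the first pattern (the suffix after i is handled
    # before i); for a brute force the order does not change the result.
    for i in range(n - 1, -1, -1):
        # C for position i: how many numbers in nums[i+1:] are strictly smaller
        counts[i] = sum(1 for x in nums[i + 1:] if x < nums[i])
    return counts
```

For `[5, 2, 6, 1]` this returns `[2, 1, 1, 0]`. Because it is obviously correct, it can serve as a reference for cross-checking the faster options on small random inputs.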

Reference: a good post for the thinking process

===================================

Notes 2020:

Brute force takes O(N^2), so we can look for O(N log N) approaches, e.g., divide and conquer, sorting, binary search, ...

Option 1: leverage merge sort

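One application of merge sort is the inversion count problem. This problem is a variant of it: instead of the total number of inversions, we need, for each number, how many smaller numbers come after it.

The key point is the merge() operation. When merging the sorted halves [left1, left2, left3, ...] and [right1, right2, right3, ...]:

  • if left1 > right1, then left2, left3, ... are all bigger than right1 as well (the left half is sorted), so right1 contributes 1 to the count of each of those numbers. In the code below, count tracks how many right-half numbers have been emitted so far; all of them are smaller than the left number being emitted, so count is added to that number's counter.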

T = O(N log N) and space is O(N).

```python
from typing import List

class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:

        def mergeSort(low, high):
            # sort pairs[low..high] (inclusive) by value, updating counters along the way
            if low >= high:
                return

            mid = low + (high - low) // 2
            mergeSort(low, mid)
            mergeSort(mid + 1, high)
            merge(low, mid, mid + 1, high)

        def merge(left1, right1, left2, right2):
            sorted_array = []

            p1, p2 = left1, left2
            count = 0  # right-half elements emitted so far; all smaller than pairs[p1]
            while p1 <= right1 and p2 <= right2:
                if pairs[p1][1] <= pairs[p2][1]:
                    sorted_array.append(pairs[p1])
                    counters[pairs[p1][0]] += count
                    p1 += 1
                else:
                    sorted_array.append(pairs[p2])
                    p2 += 1
                    count += 1

            if p1 > right1:
                # left half exhausted: remaining right elements need no counting
                for i in range(p2, right2 + 1):
                    sorted_array.append(pairs[i])
            elif p2 > right2:
                # right half exhausted: all right elements are smaller than the remaining left ones
                for i in range(p1, right1 + 1):
                    sorted_array.append(pairs[i])
                    counters[pairs[i][0]] += count

            pairs[left1: right2 + 1] = sorted_array

        n = len(nums)
        pairs = list(enumerate(nums))  # pair of (index, num)
        counters = [0] * n

        mergeSort(0, n - 1)
        return counters
```
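
Since this problem is described as a variant of inversion counting, the sketch below shows the classic merge-sort inversion counter for comparison (`sort_and_count` is an illustrative name, not part of the original notes). It relies on the same merge observation: when right[j] is emitted before left[i], it is smaller than every remaining element of the sorted left half.

```python
from typing import List, Tuple


def sort_and_count(nums: List[int]) -> Tuple[List[int], int]:
    """Return (sorted copy of nums, number of pairs i < j with nums[i] > nums[j])."""
    if len(nums) <= 1:
        return list(nums), 0
    mid = len(nums) // 2
    left, inv_left = sort_and_count(nums[:mid])
    right, inv_right = sort_and_count(nums[mid:])

    merged = []
    inversions = inv_left + inv_right
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            # right[j] < left[i] <= left[i+1] <= ..., so right[j] forms an
            # inversion with each of the len(left) - i remaining left elements
            inversions += len(left) - i
            merged.append(right[j])
            j += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, inversions
```

For `[5, 2, 6, 1]` it returns `([1, 2, 5, 6], 4)`; the total 4 equals the sum of the per-element answer `[2, 1, 1, 0]`. Option 1 differs only in attributing each inversion to its left element (via the `(index, num)` pairs) instead of summing them.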

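Option 2: convert to counting the frequency of numbers -> binary indexed tree

Map each value to its rank among the distinct values (coordinate compression), then scan from right to left: query(rank - 1) returns how many numbers already seen (i.e., to the right) are strictly smaller, and update(rank, 1) records the current number. Update and query both take O(log N) time, so the overall time complexity is O(N log N).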

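```python
from typing import List

class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:
        # coordinate compression: map each distinct value to a rank in 1..k
        ranks = {num: i+1 for i, num in enumerate(sorted(set(nums)))}
        k = len(ranks)
        binaryIndexTree = [0] * (k+1)  # 1-indexed; index 0 is unused

        def update(index, increment):
            # add increment to the frequency of rank `index`
            while index <= k:
                binaryIndexTree[index] += increment
                index += index & (-index)

        def query(index):
            # sum of the frequencies of ranks 1..index
            sum_ = 0
            while index > 0:
                sum_ += binaryIndexTree[index]
                index -= index & (-index)
            return sum_

        res = []
        for i in range(len(nums)-1, -1, -1):
            index = ranks[nums[i]]
            res.append(query(index - 1))  # numbers to the right that are strictly smaller
            update(index, 1)
        return res[::-1]
```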
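
To see what the BIT buys, here is a sketch of the same frequency-counting reduction using a plain array (`count_smaller_freq_array` is a hypothetical name). Each query is an O(k) prefix sum over the ranks, so the total is O(N * k), i.e. O(N^2) when all values are distinct; the BIT keeps the same update/query interface but makes both operations O(log k).

```python
from typing import List


def count_smaller_freq_array(nums: List[int]) -> List[int]:
    """Frequency-counting reduction with naive prefix sums (no BIT)."""
    # same coordinate compression as the BIT solution: ranks 1..k
    ranks = {num: i + 1 for i, num in enumerate(sorted(set(nums)))}
    freq = [0] * (len(ranks) + 1)  # freq[r] = how many numbers of rank r seen so far
    res = [0] * len(nums)
    for i in range(len(nums) - 1, -1, -1):
        r = ranks[nums[i]]
        res[i] = sum(freq[1:r])  # O(k) prefix sum over strictly smaller ranks
        freq[r] += 1             # O(1) update
    return res
```

For `[5, 2, 6, 1]` this returns `[2, 1, 1, 0]`, matching the BIT version.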

Option 3: binary search -> binary search tree

Binary search: each iteration takes O(log N) for the binary search but O(N) for the insertion into the sorted list, so the overall time complexity is O(N^2). This can be optimized with a binary search tree, which takes O(log N) to search and O(log N) to insert when the tree is balanced. In the BST solution, the search and the insertion are combined into a single pass down the tree.

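```python
import bisect
from typing import List

class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:
        result = []
        sorted_arr = []  # numbers to the right of num, kept sorted
        for num in nums[::-1]:
            index = bisect.bisect_left(sorted_arr, num)  # O(log N): count of elements < num
            result.append(index)
            sorted_arr.insert(index, num)  # O(N) because later elements are shifted
        return result[::-1]
```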

Each node maintains a count of the nodes in its left subtree plus itself (i.e., the values <= node.val). When val is less than or equal to node.val, we increase node.count by 1 and go left. When val is greater than node.val, val will be inserted into node's right subtree, and the node together with its whole left subtree is smaller than val, so the answer is increased by node.count.

When the tree stays balanced, T = O(N log N). For a skewed tree (e.g., sorted input), T degrades to O(N^2).

```python
from typing import List

class TreeNode:
    def __init__(self, val):
        self.val = val
        self.count = 1  # this node plus the nodes in its left subtree (values <= val)
        self.left = None
        self.right = None

class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:

        def insert(head, val):
            # insert val and return how many previously inserted values are < val
            cur = head
            count = 0
            while True:
                if val <= cur.val:
                    cur.count += 1
                    if cur.left:
                        cur = cur.left
                    else:
                        cur.left = TreeNode(val)
                        return count
                else:
                    # cur and its whole left subtree are smaller than val
                    count += cur.count
                    if cur.right:
                        cur = cur.right
                    else:
                        cur.right = TreeNode(val)
                        return count

        if not nums:
            return []

        head = TreeNode(nums[-1])
        res = [0]
        for i in range(len(nums)-2, -1, -1):
            res.append(insert(head, nums[i]))
        return res[::-1]
```

=========================================================================================================

A reference with many methods

Time complexity is O(N log N) and space complexity is O(N), where N is nums.size().

```cpp
#include <algorithm>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> countSmaller(vector<int>& nums) {
        // coordinate compression: replace each value with its 1-based rank in sorted order
        vector<int> sorted = nums;
        sort(sorted.begin(), sorted.end());
        for (int i = 0; i < (int)nums.size(); i++)
            nums[i] = findIndex(sorted, nums[i]) + 1;

        vector<int> bit(nums.size() + 1);
        vector<int> res(nums.size());
        // scan right to left: count smaller ranks already seen, then record this one
        for (int i = (int)nums.size() - 1; i >= 0; i--) {
            res[i] = query(bit, nums[i] - 1);
            add(bit, nums[i], 1);
        }
        return res;
    }
private:
    // binary search for n in sorted; with duplicates the index may not be the
    // first occurrence, but equal values always get the same index
    int findIndex(vector<int>& sorted, int n) {
        int start = 0;
        int end = sorted.size() - 1;
        while (start + 1 < end) {
            int mid = start + (end - start) / 2;
            if (sorted[mid] == n)
                return mid;
            else if (sorted[mid] < n)
                start = mid + 1;
            else //sorted[mid] > n
                end = mid - 1;
        }

        if (sorted[start] == n) return start;
        if (sorted[end] == n)   return end;
        return -1;
    }

    // add val at position i of the 1-indexed BIT
    void add(vector<int>& bit, int i, int val) {
        for (; i < (int)bit.size(); i += i & -i)
            bit[i] += val;
    }

    // prefix sum of positions 1..i
    int query(vector<int>& bit, int i) {
        int res = 0;
        for (; i > 0; i -= (i & -i))
            res += bit[i];
        return res;
    }
};
```