493. Reverse Pairs - cocoder39/coco39_LC GitHub Wiki

493. Reverse Pairs

Option 1: Binary Indexed Tree (BIT / Fenwick tree)

class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        """Count pairs (i, j) with i < j and nums[i] > 2 * nums[j].

        Scans ``nums`` right to left.  A Fenwick tree (binary indexed tree)
        over the coordinate-compressed values records every element already
        seen (i.e. every element with a larger index), so for each
        ``nums[i]`` we can count how many seen values ``v`` satisfy
        ``2 * v < nums[i]`` with one prefix-sum query.

        Time O(n log n), space O(n).
        """
        from bisect import bisect_right

        # Coordinate compression: distinct values in ascending order, each
        # mapped to a 1-based rank so it can index the Fenwick tree.
        sorted_vals = sorted(set(nums))
        ranks = {num: rank for rank, num in enumerate(sorted_vals, start=1)}
        size = len(sorted_vals)
        tree = [0] * (size + 1)

        def update(index: int) -> None:
            """Record one occurrence at rank ``index``."""
            while index <= size:
                tree[index] += 1
                index += index & -index

        def query(index: int) -> int:
            """Return the number of recorded occurrences with rank <= ``index``."""
            total = 0
            while index > 0:
                total += tree[index]
                index -= index & -index
            return total

        count = 0
        for num in reversed(nums):
            # (num + 1) // 2 - 1 is the largest integer x with 2 * x < num;
            # floor division keeps this exact for negative num too.
            # bisect_right returns how many distinct values are <= that bound,
            # which equals the rank of the largest such value (0 if none).
            rank = bisect_right(sorted_vals, (num + 1) // 2 - 1)
            count += query(rank)
            update(ranks[num])
        return count

Option 2: Merge sort

class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        
        def mergeSort(left, right):
            if left >= right:
                return
            
            mid = left + (right - left) // 2
            mergeSort(left, mid)
            mergeSort(mid+1, right)
            merge(left, mid, mid+1, right)
        
        def merge(left1, right1, left2, right2):
            count = 0
            p1, p2 = left1, left2
            while p1 <= right1 and p2 <= right2:
                if index_num_pairs[p1][1] <= 2 * index_num_pairs[p2][1]:
                    counters[index_num_pairs[p1][0]] += count
                    p1 +=1
                else:
                    count += 1
                    p2 += 1
            
            if p2 > right2:
                for i in range(p1, right1+1):
                    counters[index_num_pairs[i][0]] += count
            
            sorted_arr = []
            p1, p2 = left1, left2
            while p1 <= right1 and p2 <= right2:
                if index_num_pairs[p1][1] <= index_num_pairs[p2][1]:
                    sorted_arr.append(index_num_pairs[p1])
                    p1 += 1
                else:
                    sorted_arr.append(index_num_pairs[p2])
                    p2 += 1
            
            if p1 > right1:
                for i in range(p2, right2+1):
                    sorted_arr.append(index_num_pairs[i])
            if p2 > right2:
                for i in range(p1, right1+1):
                    sorted_arr.append(index_num_pairs[i])
            
            index_num_pairs[left1: right2+1] = sorted_arr[:]
        
        n = len(nums)
        index_num_pairs = [(i, num) for i, num in enumerate(nums)]
        counters = [0] * n
        
        mergeSort(0, n-1)
        return sum(counters)

Option 3: Binary search tree (BST)

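Time complexity: O(N log N) on average for random input, but O(N^2) when the BST becomes skewed (e.g., when the input array is sorted, so every insertion extends a single chain).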

class TreeNode:
    """Node of an unbalanced BST that also tracks its left-subtree size."""

    def __init__(self, val):
        self.val = val
        self.count = 1  # number of nodes in left subtree + itself
        self.left = None
        self.right = None


class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        """Count pairs (i, j) with i < j and nums[i] > 2 * nums[j].

        Scans ``nums`` right to left, inserting each value into a plain
        (unbalanced) BST after first counting how many previously inserted
        values ``v`` satisfy ``2 * v < nums[i]``.

        All comparisons are done in integer arithmetic; dividing by 2 with
        ``/`` would round large ints to float and miscount.

        Time O(n log n) on average, O(n^2) when the tree degenerates into a
        chain (e.g. sorted input).
        """
        if not nums:
            return 0

        def count_below_half(root: TreeNode, num: int) -> int:
            """Return how many values v stored in the tree satisfy 2 * v < num."""
            count = 0
            cur = root
            while cur is not None:
                if 2 * cur.val < num:
                    # cur and its whole left subtree (values <= cur.val) qualify.
                    count += cur.count
                    cur = cur.right
                else:
                    # Everything in cur's right subtree is > cur.val, so skip it.
                    cur = cur.left
            return count

        def insert(root: TreeNode, val: int) -> None:
            """Insert val below the (non-empty) root; equal values go left."""
            cur = root
            while True:
                if cur.val >= val:
                    cur.count += 1  # val lands somewhere in cur's left subtree
                    if cur.left is None:
                        cur.left = TreeNode(val)
                        return
                    cur = cur.left
                else:
                    if cur.right is None:
                        cur.right = TreeNode(val)
                        return
                    cur = cur.right

        root = TreeNode(nums[-1])
        count = 0
        for num in reversed(nums[:-1]):
            count += count_below_half(root, num)
            insert(root, num)
        return count