327. Count of Range Sum - cocoder39/coco39_LC GitHub Wiki

327. Count of Range Sum

Notes 2020:

Similar to "Count of Smaller Numbers After Self" (LeetCode 315).

BIT

class BinaryIndexTree:
    """Fenwick tree (binary indexed tree) that counts insertions per slot.

    Slots are 1-based: ``update(i)`` records one occurrence at slot ``i`` and
    ``query(i)`` returns the number of occurrences recorded in slots ``1..i``.
    """

    def __init__(self, n):
        """Create a tree with ``n`` slots, every count starting at zero."""
        self.n = n
        # Slot 0 is allocated but never used, so the lowest-set-bit arithmetic
        # can operate directly on the 1-based indices 1..n.
        self.arr = [0] * (n+1)

    def getLeastSignificantBit(self, i: int) -> int:
        """Return the value of the lowest set bit of ``i`` (0 when ``i`` is 0)."""
        return i & (-i)

    def update(self, index: int) -> None:
        """Record one occurrence at the 1-based slot ``index``.

        An ``index`` greater than ``n`` is ignored because the loop body never
        runs for it.

        Raises:
            ValueError: if ``index`` is below 1. Index 0 has a lowest set bit
                of 0 and would therefore never advance, so the loop would not
                terminate; negative indices reach 0 the same way.
        """
        if index < 1:
            raise ValueError(f"index must be >= 1, got {index}")
        while index <= self.n:
            self.arr[index] += 1
            # Jump to the next node whose range also covers this slot.
            index += self.getLeastSignificantBit(index)

    def query(self, index: int) -> int:
        """Return the number of occurrences recorded in slots ``1..index``.

        A non-positive ``index`` yields 0 because the loop body never runs.
        """
        total = 0
        while index > 0:
            total += self.arr[index]
            # Drop the lowest set bit to move to the preceding disjoint range.
            index -= self.getLeastSignificantBit(index)
        return total

class Solution:
    """Count range sums with ranked prefix sums stored in a binary indexed tree."""

    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count the subarrays of ``nums`` whose sum lies in ``[lower, upper]``.

        Args:
            nums: the input numbers.
            lower: inclusive lower bound on a subarray sum.
            upper: inclusive upper bound on a subarray sum.

        Returns:
            The number of contiguous ranges whose sum is within the bounds;
            0 when ``nums`` is empty or when ``lower > upper``.
        """
        # An empty interval matches nothing. Without this guard the rank
        # subtraction below would produce a negative count for lower > upper.
        if not nums or lower > upper:
            return 0

        sums = []
        total = 0
        count = 0
        for num in nums:
            total += num
            # A prefix sum is itself the sum of the range starting at index 0.
            if lower <= total <= upper:
                count += 1
            sums.append(total)
        # Coordinate compression: tree slots are ranks among the distinct sums.
        sortedSums = sorted(set(sums))

        bit = BinaryIndexTree(len(sortedSums))
        for num in sums:
            # An earlier prefix p forms a valid range with num exactly when
            # num - upper <= p <= num - lower.
            higherRank = self.getHigherRank(sortedSums, num - lower)
            lowerRank = self.getLowerRank(sortedSums, num - upper)
            if higherRank != -1 and lowerRank != -1:
                # Rank r lives in 1-based slot r + 1, so this difference counts
                # the inserted prefixes with rank in lowerRank..higherRank.
                count += bit.query(higherRank + 1) - bit.query(lowerRank)

            # Insert only after querying so a prefix never pairs with itself.
            bit.update(self.getRank(sortedSums, num) + 1)
        return count

    def getHigherRank(self, sortedArr: List[int], target: int) -> int:
        """Return the highest index whose element is ``<= target``, or -1."""
        from bisect import bisect_right

        # bisect_right is the number of elements <= target; one less is the
        # last such index, which is -1 when no element qualifies.
        return bisect_right(sortedArr, target) - 1

    def getLowerRank(self, sortedArr: List[int], target: int) -> int:
        """Return the lowest index whose element is ``>= target``, or -1."""
        from bisect import bisect_left

        position = bisect_left(sortedArr, target)
        # Landing past the end means every element is smaller than target.
        return position if position < len(sortedArr) else -1

    def getRank(self, sortedArr: List[int], target: int) -> int:
        """Return the index of ``target`` in the duplicate-free ``sortedArr``.

        ``target`` is expected to be present; if it is not, the position where
        it would be inserted is returned instead.
        """
        from bisect import bisect_left

        return bisect_left(sortedArr, target)

Merge sort

class Solution:
    """Count range sums by merge-sorting the prefix sums and counting per merge."""

    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count the subarrays of ``nums`` whose sum lies in ``[lower, upper]``.

        Args:
            nums: the input numbers.
            lower: inclusive lower bound on a subarray sum.
            upper: inclusive upper bound on a subarray sum.

        Returns:
            The number of contiguous ranges whose sum is within the bounds;
            0 when ``nums`` is empty or when ``lower > upper``.
        """
        # An empty interval matches nothing. Without this guard the two-pointer
        # window in merge() would have negative width and subtract from the total.
        if not nums or lower > upper:
            return 0

        sums = []
        sum_ = 0
        count = 0
        for num in nums:
            sum_ += num
            # A prefix sum is itself the sum of the range starting at index 0.
            if lower <= sum_ <= upper:
                count += 1
            sums.append(sum_)

        return count + self.mergeSort(sums, 0, len(sums) - 1, lower, upper)

    def mergeSort(self, sums: List[int], low: int, high: int, lowerBound: int, upperBound: int) -> int:
        """Sort ``sums[low..high]`` in place and count the qualifying pairs.

        A pair ``(i, j)`` with ``low <= i < j <= high`` qualifies when
        ``lowerBound <= sums[j] - sums[i] <= upperBound`` in the original order.

        Returns:
            The number of qualifying pairs inside ``sums[low..high]``.
        """
        if low >= high:
            return 0

        mid = low + (high - low) // 2
        left = self.mergeSort(sums, low, mid, lowerBound, upperBound)
        right = self.mergeSort(sums, mid+1, high, lowerBound, upperBound)
        # Both halves are sorted now; count the pairs that straddle the middle.
        count = self.merge(sums, low, mid, mid+1, high, lowerBound, upperBound)
        return left + right + count

    def merge(self, sums: List[int],
              left1: int, right1: int,
              left2: int, right2: int,
              low: int, high: int) -> int:
        """Count cross pairs between two sorted runs, then merge them in place.

        ``sums[left1..right1]`` and ``sums[left2..right2]`` must each be sorted
        and adjacent. The count assumes ``low <= high``; otherwise the window
        width ``b - a`` can be negative.

        Returns:
            The number of pairs ``(i, j)`` with ``i`` in the first run, ``j`` in
            the second run and ``low <= sums[j] - sums[i] <= high``.
        """
        count = 0
        # a and b only ever move right: as sums[i] grows, both window edges in
        # the sorted right run can only shift right, so the scan is linear.
        a = b = left2
        for i in range(left1, right1 + 1):
            while a <= right2 and sums[a] - sums[i] < low:
                a += 1
            while b <= right2 and sums[b] - sums[i] <= high:
                b += 1
            # sums[a..b-1] are exactly the right-run values within the bounds.
            count += b - a

        merged = []
        p1, p2 = left1, left2
        while p1 <= right1 and p2 <= right2:
            if sums[p1] <= sums[p2]:
                merged.append(sums[p1])
                p1 += 1
            else:
                merged.append(sums[p2])
                p2 += 1

        # At most one run still has elements; the other slice is empty.
        merged.extend(sums[p1:right1 + 1])
        merged.extend(sums[p2:right2 + 1])

        sums[left1:right2 + 1] = merged

        return count

======================================================================

Compare this problem with "Count of Smaller Numbers After Self": a binary indexed tree can also solve this problem, and merge sort can be used to solve that problem as well.

Tip: the value passed as end is sz + 1 rather than sz, so end is exclusive and end - 1 is the last index. Keep the following code consistent with that half-open convention:

int cnt = merge_helper(sums, start, mid, lower, upper) 
        + merge_helper(sums, mid, end, lower, upper);

T(n) = 2T(n/2) + O(n); the input can be halved at most log n times, so the time complexity is O(n log n).

// Counts the contiguous ranges of nums whose sum lies in [lower, upper] by
// merge-sorting the prefix sums and counting qualifying pairs at each merge.
class Solution {
public:
    // Returns the number of ranges whose sum is within [lower, upper];
    // 0 when nums is empty or when lower > upper.
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        // An empty interval matches nothing; without this guard the window
        // width n - m in merge_helper goes negative and corrupts the count.
        if (lower > upper)  return 0;
        int sz = nums.size();
        // long long is at least 64 bits everywhere, whereas long is only
        // 32 bits on LLP64 platforms and could overflow on large inputs.
        // sums[0] stays 0 so that sums[j] - sums[i] is the sum of nums[i..j-1].
        std::vector<long long> sums(sz + 1);
        for (int i = 0; i < sz; i++)
            sums[i + 1] = nums[i] + sums[i];
        // end is exclusive: sums holds sz + 1 entries.
        return merge_helper(sums, 0, sz + 1, lower, upper);
    }
private:
    // Sorts sums[start, end) in place and returns how many pairs i < j inside
    // that half-open range satisfy lower <= sums[j] - sums[i] <= upper.
    int merge_helper(std::vector<long long>& sums, int start, int end, int lower, int upper) {
        if (start + 1 >= end)   return 0;
        int mid = start + (end - start) / 2;
        int cnt = merge_helper(sums, start, mid, lower, upper)
                + merge_helper(sums, mid, end, lower, upper);

        // Both halves are sorted now, so m and n only ever move right while i
        // advances through the left half.
        int m = mid, n = mid;
        for (int i = start; i < mid; i++) { //O(n)
            while (m < end && sums[m] - sums[i] < lower)    m++;
            while (n < end && sums[n] - sums[i] <= upper)   n++;
            // sums[m, n) are the right-half values within the bounds.
            cnt += n - m;
        }
        //O(n) if given enough extra memory, otherwise O(n * log n)
        std::inplace_merge(sums.begin() + start, sums.begin() + mid, sums.begin() + end);
        return cnt;
    }
};
⚠️ **GitHub.com Fallback** ⚠️