327. Count of Range Sum - cocoder39/coco39_LC GitHub Wiki
Notes 2020:
Similar to "Count of Smaller Numbers After Self" (LeetCode 315).
Binary Indexed Tree (BIT) solution
class BinaryIndexTree:
    """Fenwick (binary indexed) tree holding counts at 1-based positions 1..n.

    Slot 0 of ``arr`` is unused: the lowest-set-bit arithmetic used to walk
    the tree only works for positive indices.
    """

    def __init__(self, n: int):
        self.n = n
        self.arr = [0] * (n + 1)

    def getLeastSignificantBit(self, i: int) -> int:
        """Return the lowest set bit of ``i`` (two's-complement trick)."""
        return i & (-i)

    def update(self, index: int, delta: int = 1) -> None:
        """Add ``delta`` (default 1) at 1-based position ``index``.

        Raises:
            ValueError: if ``index`` < 1. Index 0 would never advance, because
                its lowest set bit is 0, so the loop would spin forever.
        """
        if index <= 0:
            raise ValueError("BinaryIndexTree positions are 1-based")
        while index <= self.n:
            self.arr[index] += delta
            index += self.getLeastSignificantBit(index)

    def query(self, index: int) -> int:
        """Return the sum stored at positions 1..index (0 when index <= 0)."""
        total = 0
        while index > 0:
            total += self.arr[index]
            index -= self.getLeastSignificantBit(index)
        return total
class Solution:
    """Count of Range Sum using a BIT over rank-compressed prefix sums."""

    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Return the number of subarrays of ``nums`` whose sum is in [lower, upper].

        ``sums[i]`` is the sum of nums[0..i]. Subarrays that start at index 0
        are counted directly. For each later end point with prefix sum
        ``num``, every earlier prefix sum ``s`` with
        ``num - upper <= s <= num - lower`` yields one subarray. Earlier
        prefix sums are recorded in a BIT indexed by their rank among the
        distinct prefix sums.
        """
        if not nums or lower > upper:
            # An empty interval contains no sums. Without this guard the BIT
            # difference below can become negative when lower > upper.
            return 0
        sums = []
        count = 0
        total = 0
        for num in nums:
            total += num
            if lower <= total <= upper:
                count += 1
            sums.append(total)
        sortedSums = sorted(set(sums))
        bit = BinaryIndexTree(len(sortedSums))
        for num in sums:
            higherRank = self.getHigherRank(sortedSums, num - lower)
            if higherRank != -1:
                lowerRank = self.getLowerRank(sortedSums, num - upper)
                if lowerRank != -1:
                    # BIT positions are rank + 1, so query(k) counts ranks 0..k-1;
                    # the difference counts ranks lowerRank..higherRank.
                    count += bit.query(higherRank + 1) - bit.query(lowerRank)
            # Record this prefix sum only after querying, so it pairs with later ones.
            bit.update(self.getRank(sortedSums, num) + 1)
        return count

    def getHigherRank(self, sortedArr: List[int], target: int) -> int:
        """Return the highest index i with sortedArr[i] <= target, or -1 if none."""
        from bisect import bisect_right
        return bisect_right(sortedArr, target) - 1

    def getLowerRank(self, sortedArr: List[int], target: int) -> int:
        """Return the lowest index i with sortedArr[i] >= target, or -1 if none."""
        from bisect import bisect_left
        index = bisect_left(sortedArr, target)
        return index if index < len(sortedArr) else -1

    def getRank(self, sortedArr: List[int], target: int) -> int:
        """Return the index of ``target`` in ``sortedArr``.

        Callers pass a value that is present. For an absent value, this
        returns its insertion point.
        """
        from bisect import bisect_left
        return bisect_left(sortedArr, target)
Merge sort solution
class Solution:
    """Count of Range Sum using merge sort over prefix sums."""

    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Return the number of subarrays of ``nums`` whose sum is in [lower, upper].

        ``sums[i]`` is the sum of nums[0..i]. Subarrays starting at index 0
        are counted directly. Every other subarray nums[i+1..j] corresponds to
        a pair i < j with ``lower <= sums[j] - sums[i] <= upper``; those pairs
        are counted during a merge sort of ``sums``.
        """
        if not nums or lower > upper:
            # An empty interval contains no sums. Without this guard merge()
            # would add negative window sizes when lower > upper.
            return 0
        n = len(nums)
        sums = [0] * n
        count = 0
        running = 0
        for i, num in enumerate(nums):
            running += num
            if lower <= running <= upper:
                count += 1
            sums[i] = running
        return count + self.mergeSort(sums, 0, n - 1, lower, upper)

    def mergeSort(self, sums: List[int], low: int, high: int,
                  lowerBound: int, upperBound: int) -> int:
        """Sort sums[low..high] (inclusive) in place and count the valid pairs in it.

        Returns the number of index pairs i < j within [low, high] with
        ``lowerBound <= sums[j] - sums[i] <= upperBound``.
        """
        if low >= high:
            return 0
        mid = low + (high - low) // 2
        left = self.mergeSort(sums, low, mid, lowerBound, upperBound)
        right = self.mergeSort(sums, mid + 1, high, lowerBound, upperBound)
        cross = self.merge(sums, low, mid, mid + 1, high, lowerBound, upperBound)
        return left + right + cross

    def merge(self, sums: List[int],
              left1: int, right1: int,
              left2: int, right2: int,
              low: int, high: int) -> int:
        """Count cross pairs between two sorted adjacent runs, then merge them.

        Counts pairs (i in [left1, right1], j in [left2, right2]) with
        ``low <= sums[j] - sums[i] <= high``, then merges both runs into
        sums[left1..right2]. Assumes each run is sorted and right1 + 1 == left2.
        """
        count = 0
        # [a, b) is the window of right-run indices whose difference is in range.
        # sums[i] grows with i, so both pointers only move forward.
        a = b = left2
        for i in range(left1, right1 + 1):
            while a <= right2 and sums[a] - sums[i] < low:
                a += 1
            while b <= right2 and sums[b] - sums[i] <= high:
                b += 1
            count += b - a
        merged = []
        p1, p2 = left1, left2
        while p1 <= right1 and p2 <= right2:
            if sums[p1] <= sums[p2]:
                merged.append(sums[p1])
                p1 += 1
            else:
                merged.append(sums[p2])
                p2 += 1
        # At most one of these tails is non-empty.
        merged.extend(sums[p1:right1 + 1])
        merged.extend(sums[p2:right2 + 1])
        sums[left1:right2 + 1] = merged
        return count
======================================================================
Compare this problem with "Count of Smaller Numbers After Self": a binary indexed tree can also solve this problem, and merge sort can be used to solve that problem as well.
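Tips:
- The value passed as `end` is `sz + 1` (the size of `sums`, which starts with a 0) instead of `sz`. So `end - 1` is the last valid index, and every range is half-open, `[start, end)`. Keep the following code consistent with that convention:

      int cnt = merge_helper(sums, start, mid, lower, upper)
              + merge_helper(sums, mid, end, lower, upper);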
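- The running time satisfies $T(n) = 2T(n/2) + O(n)$. The range can be halved at most $\log n$ times, so the time complexity is $O(n \log n)$. This assumes the merge step is linear, i.e. `inplace_merge` can get its extra buffer.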
class Solution {
public:
    // Count the subarrays of nums whose sum lies in [lower, upper].
    // sums[k] is the sum of nums[0..k-1] (sums[0] == 0), so subarray
    // nums[i..j-1] has sum sums[j] - sums[i]. We count index pairs i < j with
    // lower <= sums[j] - sums[i] <= upper during a merge sort of sums.
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        // An empty interval contains no sums. The guard also keeps the
        // window size n - m in merge_helper non-negative.
        if (lower > upper) return 0;
        const int sz = static_cast<int>(nums.size());
        // long long: prefix sums of ints can exceed 32 bits, and `long` is
        // only 32 bits on some platforms (e.g. Windows/MSVC).
        std::vector<long long> sums(sz + 1, 0);
        for (int i = 0; i < sz; i++)
            sums[i + 1] = sums[i] + nums[i];
        return merge_helper(sums, 0, sz + 1, lower, upper);
    }

private:
    // Sorts the half-open range [start, end) of sums. Returns the number of
    // pairs i < j inside it with lower <= sums[j] - sums[i] <= upper.
    int merge_helper(std::vector<long long>& sums, int start, int end, int lower, int upper) {
        if (start + 1 >= end) return 0;
        int mid = start + (end - start) / 2;
        int cnt = merge_helper(sums, start, mid, lower, upper)
                + merge_helper(sums, mid, end, lower, upper);
        // Both halves are sorted now. For each i in the left half, [m, n) is
        // the window of right-half indices j with a difference in range.
        // sums[i] grows with i, so m and n only move forward: O(end - start).
        int m = mid, n = mid;
        for (int i = start; i < mid; i++) {
            while (m < end && sums[m] - sums[i] < lower) m++;
            while (n < end && sums[n] - sums[i] <= upper) n++;
            cnt += n - m;
        }
        // O(n) if given enough extra memory, otherwise O(n * log n)
        std::inplace_merge(sums.begin() + start, sums.begin() + mid, sums.begin() + end);
        return cnt;
    }
};