4. Median of Two Sorted Arrays - cocoder39/coco39_LC GitHub Wiki
4. Median of Two Sorted Arrays
Notes 2022
class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of two sorted lists in O(log(min(m, n))) time.

        Binary-search ``left1``, the number of ``nums1`` elements placed in the
        combined left partition; ``left2`` is then fixed by
        ``left1 + left2 == (m + n + 1) // 2``. A partition is valid when every
        element on the left is <= every element on the right. For an odd total
        the left partition holds the extra element, so the median is its max.

        Args:
            nums1: list sorted in non-decreasing order (may be empty).
            nums2: list sorted in non-decreasing order (may be empty).

        Returns:
            The median as a float.

        Raises:
            ValueError: if both lists are empty (the median is undefined).
        """
        m, n = len(nums1), len(nums2)
        if m > n:
            # Search the shorter list; this also guarantees 0 <= left2 <= n.
            return self.findMedianSortedArrays(nums2, nums1)
        if n == 0:
            raise ValueError("median of two empty lists is undefined")

        half = (m + n + 1) // 2  # size of the combined left partition
        start, end = 0, m  # search range for left1: 0 <= left1 <= m
        left1 = left2 = 0
        while start <= end:
            left1 = (start + end) // 2
            left2 = half - left1
            if left1 > 0 and left2 < n and nums1[left1 - 1] > nums2[left2]:
                # nums1's left max is too large: give nums1 fewer left elements.
                end = left1 - 1
            elif left1 < m and left2 > 0 and nums2[left2 - 1] > nums1[left1]:
                # nums1's right min is too small: give nums1 more left elements.
                start = left1 + 1
            else:
                # Valid partition: each cross comparison holds or its side is empty.
                break

        if left1 == 0:
            left_max = nums2[left2 - 1]
        elif left2 == 0:
            left_max = nums1[left1 - 1]
        else:
            left_max = max(nums1[left1 - 1], nums2[left2 - 1])
        if (m + n) % 2:  # odd total: the median is the left partition's max
            return float(left_max)

        if left1 == m:
            right_min = nums2[left2]
        elif left2 == n:
            right_min = nums1[left1]
        else:
            right_min = min(nums1[left1], nums2[left2])
        # True division: the median of an even count may be fractional.
        return (left_max + right_min) / 2
Why 0 <= left2 <= n: left2 = (m+n+1) // 2 - left1, where 0 <= left1 <= m and m <= n.
Upper bound: left2 <= (m+n+1) // 2 <= (2n+1) // 2 = n. Lower bound: left2 >= (m+n+1) // 2 - m = (n-m+1) // 2 >= 0.
=====================================================================
Notes 2021:
- Assume m <= n without loss of generality to simplify the discussion:
# Recurse with the arguments swapped so the first array is never the longer one (m <= n).
if m > n:
return self.findMedianSortedArrays(nums2, nums1)
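- 0 <= right1 <= m implies 0 <= right2 <= n (this relies on m <= n).
- Use `low <= high` as the binary-search loop condition; `low + 1 < high` would make the boundary cases harder to reason about.
- right2 = (m + n + 1) // 2 - right1. The `+1` gives one consistent formula for both parities: when m + n is odd, the left partition holds the extra element, so the median is max(left partition).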
solution 1:
class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of two sorted lists in O(log(min(m, n))) time.

        ``right1``/``right2`` are the indices of the first element on the right
        side of each list's partition, i.e. how many elements each list puts
        in the combined left half of size ``(m + n + 1) // 2``. Missing
        neighbours at the list ends are replaced by -inf/+inf so the validity
        test needs no special cases.

        Args:
            nums1: list sorted in non-decreasing order (may be empty).
            nums2: list sorted in non-decreasing order (may be empty).

        Returns:
            The median as a float.

        Raises:
            ValueError: if both lists are empty (the median is undefined).
        """
        m, n = len(nums1), len(nums2)
        if m > n:
            # Search the shorter list; this also guarantees 0 <= right2 <= n.
            return self.findMedianSortedArrays(nums2, nums1)
        if n == 0:
            raise ValueError("median of two empty lists is undefined")

        neg_inf, pos_inf = float("-inf"), float("inf")
        half = (m + n + 1) // 2  # the left half holds the extra element when odd
        low, high = 0, m
        while low <= high:
            right1 = (low + high) // 2
            right2 = half - right1
            left_max1 = nums1[right1 - 1] if right1 > 0 else neg_inf
            right_min1 = nums1[right1] if right1 < m else pos_inf
            left_max2 = nums2[right2 - 1] if right2 > 0 else neg_inf
            right_min2 = nums2[right2] if right2 < n else pos_inf
            if left_max1 > right_min2:
                high = right1 - 1  # too many elements taken from nums1
            elif left_max2 > right_min1:
                low = right1 + 1  # too few elements taken from nums1
            else:
                # Valid partition. half >= 1, so left_max is a real element;
                # for an even total the right side is non-empty too.
                left_max = max(left_max1, left_max2)
                if (m + n) % 2:
                    return float(left_max)
                return (left_max + min(right_min1, right_min2)) / 2
        # Unreachable for sorted input: some partition is always valid.
        raise ValueError("input lists must be sorted in non-decreasing order")
Option 2
class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of two sorted lists via k-th-smallest selection.

        Each step of ``kth`` discards ``k // 2`` elements that provably rank
        below the k-th, giving O(log(m + n)) time and O(1) extra space.

        Args:
            nums1: list sorted in non-decreasing order (may be empty).
            nums2: list sorted in non-decreasing order (may be empty).

        Returns:
            The median as a float.

        Raises:
            ValueError: if both lists are empty (the median is undefined).
        """
        m, n = len(nums1), len(nums2)
        total = m + n
        if total == 0:
            raise ValueError("median of two empty lists is undefined")

        def kth(k: int) -> int:
            """Return the k-th smallest (1-based) element of nums1 and nums2 combined."""
            start1 = start2 = 0
            while True:
                if start1 >= m:
                    return nums2[start2 + k - 1]
                if start2 >= n:
                    return nums1[start1 + k - 1]
                if k == 1:
                    return min(nums1[start1], nums2[start2])
                step = k // 2  # >= 1 because k >= 2 here
                # An out-of-range probe counts as +inf so that list is not pruned.
                idx1, idx2 = start1 + step - 1, start2 + step - 1
                mid1 = nums1[idx1] if idx1 < m else float("inf")
                mid2 = nums2[idx2] if idx2 < n else float("inf")
                # The smaller probe and everything before it rank below k.
                if mid1 < mid2:
                    start1 += step
                else:
                    start2 += step
                k -= step

        if total % 2:  # odd: a single middle element, search only once
            return float(kth(total // 2 + 1))
        return (kth(total // 2) + kth(total // 2 + 1)) / 2
===================================================================================================
Tips: time is O(log k) = O(log(m + n)), and the recursion stack is also O(log(m + n)).
Clarifications:
- 1. Assume the two arrays together hold at least one element (m + n >= 1). Either array on its own may be empty, and the code below handles that. Check with the interviewer how an entirely empty input should be handled.
- 2. Based on 1, the problem reduces to returning the k-th smallest element of the two arrays, where k is about (m+n)/2. Odd and even totals differ in the real implementation: for an odd m+n, k = (m+n+1)/2; for an even m+n, average the (m+n)/2-th and ((m+n)/2 + 1)-th elements. Thus k >= (1+1)/2 = 1.
- 3. Based on 2, we can write the code below. Several places need care:
  - Even if the (k/2)-th elements of array1 and array2 are equal, we cannot conclude that either is the k-th element, because k may be 2*(k/2) or 2*(k/2) + 1. We can only be sure where the k-th element is NOT, and exclude that part in the next iteration.
  - Use INT_MAX to fill positions past the end of an array. This does not affect the k-th element and makes the comparison simpler.
  - Ensure k/2 - 1 >= 0 before using idx.
- 4. Based on 3, we need k/2 >= 1, but (1) and (2) only guarantee k >= 1, so k == 1 must be handled as a separate base case. Before that base case, check that start1 and start2 are still within their arrays.
// Base case k == 1: the smallest remaining element is the smaller of the two current heads.
// Only valid once start1 and start2 have already been checked to be in range.
if (k == 1) {
return min(nums1[start1], nums2[start2]); //start1 and start2 must be bounds-checked before this line
}
Finally we have code below.
// Snippet from kth(): probe the (k/2)-th remaining element of each array.
int idx1 = start1 + (k/2 - 1); //kth's index is k-1
int idx2 = start2 + (k/2 - 1); //ensure k/2 >= 1 before using idx
// A probe past the end of its array becomes INT_MAX, which steers the comparison toward pruning the other array.
int mid1 = idx1 < nums1.size() ? nums1[idx1] : INT_MAX;
int mid2 = idx2 < nums2.size() ? nums2[idx2] : INT_MAX;
/*
if (mid1 < mid2) {
return kth(nums1, idx1 + 1, nums2, start2, k - k/2);
} else if (mid1 > mid2) {
return kth(nums1, start1, nums2, idx2 + 1, k - k/2);
} else {
//bug implementation since k/2 + k/2 <= k, cannot guarantee mid1 is the kth element
//can ensure where there is no target
return mid1;
}
*/
// Correct version: never return early on equality; always discard k/2 elements from one side.
if (mid1 < mid2) {
return kth(nums1, idx1 + 1, nums2, start2, k - k/2);
} else {
//ignore mid1 regardless that k == 2*(k/2) or k == 2*(k/2) + 1
//even if k == 2*(k/2), mid2 can be the kth element
return kth(nums1, start1, nums2, idx2 + 1, k - k/2);
}
- The terminal case is `k == 1`, where the function returns `min(nums1[start1], nums2[start2])`.
- If a mid index is out of bounds, assign that mid `INT_MAX`. This dummy value steers the comparison toward pruning the other array, so the k-th element is never discarded.
class Solution {
public:
    // Returns the median of two sorted arrays in O(log(m + n)) time by
    // selecting the middle element(s) with kth().
    // Precondition: nums1.size() + nums2.size() >= 1; either array alone may be empty.
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int m = static_cast<int>(nums1.size()), n = static_cast<int>(nums2.size());
        if ((m + n) % 2 == 0) { // even: average the two middle elements
            // Widen to double before adding: two ints near INT_MAX/INT_MIN would overflow.
            double lower = kth(nums1, 0, nums2, 0, (m + n) / 2);
            double upper = kth(nums1, 0, nums2, 0, (m + n) / 2 + 1);
            return (lower + upper) / 2.0;
        }
        return kth(nums1, 0, nums2, 0, (m + n + 1) / 2); // odd: single middle element
    }

private:
    // Returns the k-th smallest (1-based) element of nums1[start1..] and nums2[start2..].
    // Requires 1 <= k <= number of remaining elements, which the caller guarantees
    // when the arrays hold at least one element in total.
    int kth(vector<int>& nums1, int start1, vector<int>& nums2, int start2, int k) {
        const int m = static_cast<int>(nums1.size());
        const int n = static_cast<int>(nums2.size());
        if (start1 >= m) {
            return nums2[start2 + (k - 1)];
        }
        if (start2 >= n) {
            return nums1[start1 + (k - 1)];
        }
        if (k == 1) { // both starts are in range here
            return min(nums1[start1], nums2[start2]);
        }
        const int half = k / 2;          // >= 1 because k >= 2 here
        const int idx1 = start1 + half - 1;
        const int idx2 = start2 + half - 1;
        // A probe past the end becomes INT_MAX, steering the prune to the other array.
        const int mid1 = idx1 < m ? nums1[idx1] : INT_MAX;
        const int mid2 = idx2 < n ? nums2[idx2] : INT_MAX;
        // Never return early when mid1 == mid2: half + half may equal k - 1, so
        // equality does not make mid1 the k-th element. We only know that the
        // `half` elements up to the smaller probe rank below k, so drop them.
        if (mid1 < mid2) {
            return kth(nums1, idx1 + 1, nums2, start2, k - half);
        }
        return kth(nums1, start1, nums2, idx2 + 1, k - half);
    }
};
solution 2: time is O(log min(m, n))
divide both arrays into two parts
left_part | right_part
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
we want a bi-partition such that
1) len(left_part) == len(right_part) (+ 1)
2) max(left_part) <= min(right_part)
where i (the number of elements of A placed in left_part) ranges over [0, m]
in other words:
(1) i + j == (m - i) + (n - j) (+ 1)
if n >= m, we set i = 0 ~ m, then j = (m + n + 1)/2 - i to achieve bi-partition
(2) B[j-1] <= A[i] and A[i-1] <= B[j]
further
Searching i in [0, m], to find an `i` such that:
B[j-1] <= A[i] and A[i-1] <= B[j], ( where i + j = (m + n + 1)/2)
// Returns the median of two sorted arrays in O(log(min(m, n))) time.
// Binary-searches left1, the number of nums1 elements in the combined left
// partition; left2 follows from left1 + left2 == (m + n + 1) / 2. A partition
// is valid when every left element is <= every right element.
// Precondition: nums1.size() + nums2.size() >= 1; either array alone may be empty.
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
    int m = static_cast<int>(nums1.size()), n = static_cast<int>(nums2.size());
    if (m > n) { // search the shorter array: O(log min(m, n)) and 0 <= left2 <= n
        return findMedianSortedArrays(nums2, nums1);
    }
    const int half = (m + n + 1) / 2; // left partition holds the extra element when odd
    int left1 = 0, left2 = 0;         // elements of nums1 / nums2 in the left partition
    int start = 0, end = m;           // searching range of left1 is [0, m]
    while (start <= end) {
        left1 = start + (end - start) / 2;
        left2 = half - left1;
        if (left1 >= 1 && left2 < n && nums1[left1 - 1] > nums2[left2]) {
            end = left1 - 1;   // nums1 contributes too many left elements
        } else if (left2 >= 1 && left1 < m && nums2[left2 - 1] > nums1[left1]) {
            start = left1 + 1; // nums1 contributes too few left elements
        } else {
            // Valid partition: each cross comparison holds or its side is empty.
            break;
        }
    }
    int left_max;
    if (left1 == 0) {
        left_max = nums2[left2 - 1];
    } else if (left2 == 0) {
        left_max = nums1[left1 - 1];
    } else {
        left_max = max(nums1[left1 - 1], nums2[left2 - 1]);
    }
    if ((m + n) % 2) { // odd: the median is the left partition's max
        return left_max;
    }
    int right_min;
    if (left1 == m) {
        right_min = nums2[left2];
    } else if (left2 == n) {
        right_min = nums1[left1];
    } else {
        right_min = min(nums1[left1], nums2[left2]);
    }
    // Widen before adding: left_max + right_min can overflow int.
    return (static_cast<double>(left_max) + right_min) / 2.0;
}
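It can also be used to solve "find the k-th element" by forcing left1 + left2 = k. In that case, search left1 in [max(0, k - n), min(k, m)] so that 0 <= left2 <= n still holds. The same idea, splitting the data into a left half and a right half, is used in 295. Find Median from Data Stream.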