4. Median of Two Sorted Arrays - cocoder39/coco39_LC GitHub Wiki

4. Median of Two Sorted Arrays

Notes 2022

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of the values of two sorted lists taken together.

        Binary-searches ``left1``, the number of elements of the shorter list
        that belong to the left half of the merged order. The number taken
        from the longer list follows from
        ``left1 + left2 == (m + n + 1) // 2``, so the left half holds the
        extra element when the combined length is odd.

        Args:
            nums1: First sorted list.
            nums2: Second sorted list.

        Returns:
            The median as a float, for both odd and even combined lengths.

        Raises:
            IndexError: If both lists are empty.
        """
        m, n = len(nums1), len(nums2)
        if m > n:  # without loss of generality, search the shorter list
            return self.findMedianSortedArrays(nums2, nums1)

        half = (m + n + 1) // 2  # size of the combined left partition
        left1 = left2 = 0  # sizes of nums1's and nums2's left partitions
        start, end = 0, m  # 0 <= left1 <= m
        while start <= end:
            left1 = start + (end - start) // 2
            left2 = half - left1  # with m <= n this stays within [0, n]
            if left1 >= 1 and left2 < n and nums1[left1 - 1] > nums2[left2]:
                # max of nums1's left part exceeds min of nums2's right part:
                # too many elements taken from nums1, so shrink its left part.
                end = left1 - 1
            elif left1 < m and left2 >= 1 and nums2[left2 - 1] > nums1[left1]:
                # min of nums1's right part is below max of nums2's left part:
                # too few elements taken from nums1, so grow its left part.
                start = left1 + 1
            else:
                # Either the cross conditions hold or a partition is empty
                # (left1 == 0, left1 == m, left2 == 0 or left2 == n).
                break

        if left1 == 0:
            left_max = nums2[left2 - 1]
        elif left2 == 0:
            left_max = nums1[left1 - 1]
        else:
            left_max = max(nums1[left1 - 1], nums2[left2 - 1])

        if (m + n) % 2:  # odd: the median is the largest left element
            # Convert so that both parities return the annotated float type.
            return float(left_max)

        if left1 == m:
            right_min = nums2[left2]
        elif left2 == n:
            right_min = nums1[left1]
        else:
            right_min = min(nums1[left1], nums2[left2])
        return (left_max + right_min) / 2  # median, so / instead of //
Why 0 <= left2 <= n: left2 = (m+n+1) // 2 - left1 with 0 <= left1 <= m (i.e. -m <= -left1 <= 0) and m <= n,
so left2 <= (m+n+1) // 2 <= (2n+1) // 2 = n and left2 >= (m+n+1) // 2 - m = (n-m+1) // 2 >= 0

=====================================================================

Notes 2021:

  1. let m <= n to simplify the discussion without loss of generality
if m > n:
            return self.findMedianSortedArrays(nums2, nums1)
  1. 0 <= right1 <= m => 0 <= right2 <= n

  2. use low <= high as the binary search loop condition; low+1 < high would make the boundary cases harder to reason about

  3. right2 = (m + n + 1) // 2 - right1 where +1 is to provide a consistent approach to deal with odd sum

solution 1:

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of the values of two sorted lists taken together.

        ``right1`` and ``right2`` are the indices of the first element on the
        right side of the partition in ``nums1`` and ``nums2``. ``right1`` is
        binary-searched over ``[0, m]`` and ``right2`` is derived from
        ``right1 + right2 == (m + n + 1) // 2``; the ``+ 1`` puts the extra
        element on the left side when the combined length is odd.

        Args:
            nums1: First sorted list.
            nums2: Second sorted list.

        Returns:
            The median as a float, for both odd and even combined lengths.

        Raises:
            IndexError: If both lists are empty.
        """
        m, n = len(nums1), len(nums2)
        if m > n:  # keep nums1 as the shorter list
            return self.findMedianSortedArrays(nums2, nums1)

        left_size = (m + n + 1) // 2
        low, high = 0, m
        right1 = right2 = 0
        while low <= high:
            # index of the first element on the right side of each list
            right1 = low + (high - low) // 2
            right2 = left_size - right1

            # 0 <= right1 <= m and 0 <= right2 <= n

            if 1 <= right1 and right2 < n and nums1[right1 - 1] > nums2[right2]:
                high = right1 - 1  # nums1 contributes too much to the left
            elif right1 < m and 1 <= right2 and nums1[right1] < nums2[right2 - 1]:
                low = right1 + 1  # nums1 contributes too little to the left
            else:
                # valid partition, possibly with an empty side:
                # right1 == 0 or right2 == n or right1 == m or right2 == 0
                break

        if right1 == 0:
            left_max = nums2[right2 - 1]
        elif right2 == 0:
            left_max = nums1[right1 - 1]
        else:
            left_max = max(nums1[right1 - 1], nums2[right2 - 1])

        if (m + n) % 2:
            # Odd combined length: convert so the result type matches the
            # float produced by the even branch.
            return float(left_max)

        if right1 == m:
            right_min = nums2[right2]
        elif right2 == n:
            right_min = nums1[right1]
        else:
            right_min = min(nums1[right1], nums2[right2])

        return (left_max + right_min) / 2

Option 2

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        """Return the median of two sorted lists via k-th smallest selection.

        The median is the average of the ((m+n+1)//2)-th and the
        ((m+n+2)//2)-th smallest elements; these coincide when the combined
        length is odd.
        """
        m, n = len(nums1), len(nums2)

        def kth_smallest(k):
            """Return the k-th smallest (1-based) element of both lists."""
            i, j = 0, 0  # first not-yet-discarded index in nums1 / nums2
            while True:
                if i >= m:
                    return nums2[j + k - 1]
                if j >= n:
                    return nums1[i + k - 1]
                if k == 1:
                    return min(nums1[i], nums2[j])

                step = k // 2
                probe1 = i + step - 1
                probe2 = j + step - 1
                # A probe past the end counts as +inf so that list is not
                # the one chosen for discarding.
                cand1 = nums1[probe1] if probe1 < m else float("inf")
                cand2 = nums2[probe2] if probe2 < n else float("inf")

                # Drop `step` elements from the list with the smaller probe;
                # on a tie the elements are dropped from nums2.
                if cand1 < cand2:
                    i += step
                else:
                    j += step
                k -= step

        lower = kth_smallest((m + n + 1) // 2)
        upper = kth_smallest((m + n + 2) // 2)
        return (lower + upper) / 2
            

===================================================================================================

tips: time is O(log k) = O(log(m + n)), stack space is also O(log (m + n)) clarification:

  • 1 there exists at least one element at each input array if we are asked to return median of both. you can also check with interviewer how to handle empty array
  • 2 based on 1, the problem is transferred to return the kth element among two arrays, where k = (m+n)/2 ((not exactly, consider even and odd in real implementation)). Thus k >= (1+1)/2 = 1
  • 3 based on 2, we may have code below. Several places need take care

Even if the (k/2)th element of array1 and the (k/2)th element of array2 are equal, we cannot conclude that they are the kth element of the two arrays combined, because k == 2*(k/2) or k == 2*(k/2) + 1. We can only identify a range that cannot contain the kth element and exclude that range in the next iteration.

Use INT_MAX to stand in for positions past the end of an array; this does not affect the kth element and makes the comparison easier.

ensure k/2 - 1 >= 0 before using idx

    1. Based on 3, we need to guarantee k/2 >= 1, but so far we have only guaranteed k >= 1 (see 2). So handle k == 1 separately, but check start1 and start2 before that.
if (k == 1) {
            return min(nums1[start1], nums2[start2]); //ensure start1 and start2 before using them
        }

Finally we have code below.

int idx1 = start1 + (k/2 - 1); //kth's index is k-1
        int idx2 = start2 + (k/2 - 1); //ensure k/2 >= 1 before using idx
        int mid1 = idx1 < nums1.size() ? nums1[idx1] : INT_MAX;
        int mid2 = idx2 < nums2.size() ? nums2[idx2] : INT_MAX;
        /* 
        if (mid1 < mid2) {
            return kth(nums1, idx1 + 1, nums2, start2, k - k/2);
        } else if (mid1 > mid2) {
            return kth(nums1, start1, nums2, idx2 + 1, k - k/2);
        } else { 
            //bug implementation since k/2 + k/2 <= k, cannot guarantee mid1 is the kth element
            //can ensure where there is no target
            return mid1;
        }
        */
        if (mid1 < mid2) {
            return kth(nums1, idx1 + 1, nums2, start2, k - k/2);
        } else { 
            //ignore mid1 regardless that k == 2*(k/2) or k == 2*(k/2) + 1
            //even if k == 2*(k/2), mid2 can be the kth element
            return kth(nums1, start1, nums2, idx2 + 1, k - k/2);
        } 
  • terminal case is k == 1, where returns min(nums1[start1], nums2[start2])
  • if mid is out of bound, assign mid with INT_MAX, which can be viewed as dummy number, guarantees the mid would not be the kth.
class Solution {
public:
   // Returns the median of the combined elements of two sorted vectors.
   // Precondition: at least one of the vectors is non-empty, so every k
   // passed to kth() is >= 1.
   double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
       const int m = static_cast<int>(nums1.size());
       const int n = static_cast<int>(nums2.size());
       if ((m + n) % 2 == 0) { //even
           // Widen to double before adding: the sum of two ints can overflow int.
           const double lower = kth(nums1, 0, nums2, 0, (m + n) / 2);
           const double upper = kth(nums1, 0, nums2, 0, (m + n) / 2 + 1);
           return (lower + upper) / 2.0;
       } else { //odd
           return kth(nums1, 0, nums2, 0, (m + n + 1) / 2);
       }
   }
private:
   // Returns the kth smallest (1-based) element among nums1[start1..] and
   // nums2[start2..]. Requires 1 <= k <= number of remaining elements.
   int kth(vector<int>& nums1, int start1, vector<int>& nums2, int start2, int k) {
       const int m = static_cast<int>(nums1.size());
       const int n = static_cast<int>(nums2.size());
       // Check start1 and start2 before indexing with them in the k == 1 case.
       if (start1 >= m) {
           return nums2[start2 + (k - 1)];
       } else if (start2 >= n) {
           return nums1[start1 + (k - 1)];
       } else if (k == 1) {
           return min(nums1[start1], nums2[start2]);
       }

       // k >= 2 here, so k/2 >= 1 and both probe indices are >= their start.
       const int idx1 = start1 + (k/2 - 1); //kth's index is k-1
       const int idx2 = start2 + (k/2 - 1);
       // An out-of-range probe is replaced by the dummy value INT_MAX.
       const int mid1 = idx1 < m ? nums1[idx1] : INT_MAX;
       const int mid2 = idx2 < n ? nums2[idx2] : INT_MAX;

       // Do NOT return mid1 when mid1 == mid2: since k/2 + k/2 <= k, equality
       // does not guarantee mid1 is the kth element. We only know a range that
       // cannot contain the target, so discard it and recurse.
       if (mid1 < mid2) {
           return kth(nums1, idx1 + 1, nums2, start2, k - k/2);
       } else {
           //ignore mid1 regardless that k == 2*(k/2) or k == 2*(k/2) + 1
           //even if k == 2*(k/2), mid2 can be the kth element
           return kth(nums1, start1, nums2, idx2 + 1, k - k/2);
       }
   }
};

solution 2: time is O(log min(m, n))

divide both arrays into two parts
 left_part               |        right_part
A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]

we want a bi-partition such that
1) len(left_part) == len(right_part) (+ 1)
2) max(left_part) <= min(right_part)
where i, the number of elements of A placed in left_part, lies in [0, m]

in other words:
(1) i + j == (m - i) + (n - j) (+ 1)
    if n >= m, we set i = 0 ~ m, then j = (m + n + 1)/2 - i to achieve bi-partition
(2) B[j-1] <= A[i] and A[i-1] <= B[j]

further
Searching i in [0, m], to find an `i` such that:
    B[j-1] <= A[i] and A[i-1] <= B[j], ( where i + j = (m + n + 1)/2)
// Returns the median of the combined elements of two sorted vectors by
// binary-searching how many elements of the shorter vector go into the left
// partition. Precondition: at least one of the vectors is non-empty.
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
    const int m = static_cast<int>(nums1.size());
    const int n = static_cast<int>(nums2.size());
    if (m > n) { //make time complexity be O(log min(m, n))
        return findMedianSortedArrays(nums2, nums1);
    }

    //number of elements of nums1 and nums2 that are in the left partition
    int left1 = 0, left2 = 0;
    int start = 0, end = m; //searching range of left1 is [0, m]
    while (start <= end) {
        left1 = start + (end - start) / 2;
        //bi-partition left and right sides, left1 + left2 = (m + n + 1) / 2
        //if m+n is odd, the median is left_max, else the mean of left_max and right_min
        left2 = (m + n + 1) / 2 - left1;

        if (left1 >= 1 && left2 < n && nums1[left1 - 1] > nums2[left2]) {
            end = left1 - 1;
        } else if (left2 >= 1 && left1 < m && nums2[left2 - 1] > nums1[left1]) {
            start = left1 + 1;
        } else { //left1 == 0 || left2 == n || left2 == 0 || left1 == m ||
                 //(nums1[left1 - 1] <= nums2[left2] && nums2[left2 - 1] <= nums1[left1])
                 //all corner cases can meet the bi-partition requirement
            break;
        }
    }

    int left_max;
    if (left1 == 0) {
        left_max = nums2[left2 - 1];
    } else if (left2 == 0) {
        left_max = nums1[left1 - 1];
    } else {
        left_max = max(nums1[left1 - 1], nums2[left2 - 1]);
    }
    if ((m + n) % 2) { //odd
        return left_max;
    }

    int right_min;
    if (left1 == m) {
        right_min = nums2[left2];
    } else if (left2 == n) {
        right_min = nums1[left1];
    } else {
        right_min = min(nums1[left1], nums2[left2]);
    }
    // Widen to double before adding: left_max + right_min can overflow int.
    return (static_cast<double>(left_max) + right_min) / 2.0;
}

it can be used to solve "find kth element" by forcing left1 + left2 = k. same idea is used in 295. Find Median from Data Stream

⚠️ **GitHub.com Fallback** ⚠️