Modified Binary Search - kjingers/Grokking-Notes GitHub Wiki

Whenever we are given a sorted Array, LinkedList, or Matrix and are asked to find a certain element, the best algorithm we can use is usually Binary Search.

start = 0
end = len(arr) - 1
middle = start + (end - start) // 2

if key < arr[middle]:
  end = middle - 1

if key > arr[middle]:
  start = middle + 1
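
The fragment above shows a single step. As a minimal sketch, assuming an ascending array, the full loop could look like the following. The function name binary_search_ascending is made up for this example; the sections below use their own names.

```python
def binary_search_ascending(arr, key):
    """Return the index of key in ascending-sorted arr, or -1 if absent."""
    start, end = 0, len(arr) - 1
    while start <= end:
        # Same midpoint formula as above; integer division keeps mid an int
        mid = start + (end - start) // 2
        if key < arr[mid]:
            end = mid - 1      # key can only be in the left half
        elif key > arr[mid]:
            start = mid + 1    # key can only be in the right half
        else:
            return mid
    return -1


print(binary_search_ascending([1, 3, 8, 10, 15], 10))  # 3
print(binary_search_ascending([1, 3, 8, 10, 15], 12))  # -1
print(binary_search_ascending([], 5))                  # -1
```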

Order-agnostic Binary Search

def binary_search(arr, key):
  start = 0
  end = len(arr) - 1
  isAscending = arr[end] > arr[start]

  while end >= start:
    mid = start + (end - start) // 2

    if arr[mid] == key:
      return mid
    
    if isAscending:
      # In first half
      if key < arr[mid]:
        end = mid - 1
      # In second half
      else:
        start = mid + 1
    else:
      # In first half
      if key > arr[mid]:
        end = mid - 1
      # In second half
      else:
        start = mid + 1

  return -1


def main():
  print(binary_search([4, 6, 10], 10))
  print(binary_search([1, 2, 3, 4, 5, 6, 7], 5))
  print(binary_search([10, 6, 4], 10))
  print(binary_search([10, 6, 4], 4))


main()

Ceiling of a Number

Same idea as a normal binary search, except if the number doesn't exist in the array, then the index of start when the loop breaks will point to the next larger element (start = end + 1)

def search_ceiling_of_a_number(arr, key):

  # If key is greater than the largest element, return -1
  if key > arr[-1]:
    return -1

  start = 0
  end = len(arr) - 1


  while start <= end:
    mid = start + (end - start) // 2

    if arr[mid] == key:
      return mid

    # First Half
    if key < arr[mid]:
      end = mid - 1
    # Second Half
    else:
      start = mid + 1
  
  return start


def main():
  print(search_ceiling_of_a_number([4, 6, 10], 6))
  print(search_ceiling_of_a_number([1, 3, 8, 10, 15], 12))
  print(search_ceiling_of_a_number([4, 6, 10], 17))
  print(search_ceiling_of_a_number([4, 6, 10], -1))


main()

Floor in a Sorted Array

Same as above, except we need to check if key < arr[0], and we return end (the index of the next smaller element, arr[end]) when the loop breaks.

def search_floor_of_a_number(arr, key):
  if key < arr[0]:  # if the 'key' is smaller than the smallest element
    return -1

  start, end = 0, len(arr) - 1
  while start <= end:
    mid = start + (end - start) // 2
    if key < arr[mid]:
      end = mid - 1
    elif key > arr[mid]:
      start = mid + 1
    else:  # found the key
      return mid

  # since the loop is running until 'start <= end', so at the end of the while loop, 'start == end+1'
  # we are not able to find the element in the given array, so the next smaller number will be arr[end]
  return end


def main():
  print(search_floor_of_a_number([4, 6, 10], 6))
  print(search_floor_of_a_number([1, 3, 8, 10, 15], 12))
  print(search_floor_of_a_number([4, 6, 10], 17))
  print(search_floor_of_a_number([4, 6, 10], -1))


main()
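
As a cross-check on the ceiling and floor logic above, Python's standard bisect module can compute the same indices. This is only a sketch for comparison, not the approach used in these notes; the names ceiling_index and floor_index are made up for this example.

```python
from bisect import bisect_left, bisect_right


def ceiling_index(arr, key):
    # bisect_left returns the first index i with arr[i] >= key
    i = bisect_left(arr, key)
    return i if i < len(arr) else -1


def floor_index(arr, key):
    # bisect_right returns the first index i with arr[i] > key,
    # so the floor (last element <= key) sits just before it.
    # If key is smaller than every element this is 0 - 1 == -1.
    return bisect_right(arr, key) - 1


print(ceiling_index([1, 3, 8, 10, 15], 12))  # 4
print(ceiling_index([4, 6, 10], 17))         # -1
print(floor_index([1, 3, 8, 10, 15], 12))    # 3
print(floor_index([4, 6, 10], -1))           # -1
```

One difference to keep in mind: with duplicate keys, ceiling_index returns the first matching index and floor_index the last, while the hand-written loops return whichever match they hit first.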

Find Smallest Letter Greater Than Target

My solution was a bit longer than the provided solution.

def search_next_letter(letters, key):

  if len(letters) < 1:
    return None

  # If key >= the largest letter, wrap around and return letters[0]
  if key >= letters[-1]:
    return letters[0]

  start = 0
  end = len(letters) - 1

  while start <= end:
    mid = start + (end - start) // 2

    # On an exact match, the next letter is at mid + 1 (assumes no duplicate letters).
    # The rollover case was already handled before the loop.
    if key == letters[mid]:
      return letters[mid + 1]
    
    # First Half
    if key < letters[mid]:
      end = mid - 1
    else:
      start = mid + 1
  
  return letters[start]


def main():
  print(search_next_letter(['a', 'c', 'f', 'h'], 'f'))
  print(search_next_letter(['a', 'c', 'f', 'h'], 'b'))
  print(search_next_letter(['a', 'c', 'f', 'h'], 'm'))


main()

Solution provided

Really, we just needed a normal binary search. We don't need to check whether we ever find a match. Instead, move start forward even when key == letters[mid]. That way, the loop always ends with start = end + 1, pointing at the smallest letter greater than key.

Second, just return letters[start % n] to account for wrap-around.

def search_next_letter(letters, key):
  n = len(letters)

  start, end = 0, n - 1
  while start <= end:
    mid = start + (end - start) // 2
    if key < letters[mid]:
      end = mid - 1
    else: # key >= letters[mid]:
      start = mid + 1

  # since the loop is running until 'start <= end', so at the end of the while loop, 'start == end+1'
  return letters[start % n]


def main():
  print(search_next_letter(['a', 'c', 'f', 'h'], 'f'))
  print(search_next_letter(['a', 'c', 'f', 'h'], 'b'))
  print(search_next_letter(['a', 'c', 'f', 'h'], 'm'))


main()

Find First and Last Position of Element in Sorted Array

Very similar to a normal binary search. For this problem, we run a binary search twice to find the first and last occurrences, respectively. We do this by passing in a bool.

We will try to search for the ‘key’ in the given array; if the ‘key’ is found (i.e. key == arr[middle]) we have two options:

  • When trying to find the first position of the ‘key’, we can update end = middle - 1 to see if the key is present before middle.
  • When trying to find the last position of the ‘key’, we can update start = middle + 1 to see if the key is present after middle.

In both cases, we will keep track of the last keyIndex found.

def find_range(arr, key):
  result = [-1, -1]

  # First, find the first occurrence
  result[0] = binary_search(arr, key, False)

  # If key exists in arr, then find the last occurrence
  if result[0] > -1:
    result[1] = binary_search(arr, key, True)

  return result

# If findMax, then find the last occurrence. Else, find the first occurrence
def binary_search(arr, key, findMax):

  start, end = 0, len(arr) - 1

  # Default keyIndex is -1 (not in arr)
  keyIndex = -1

  while start <= end:
    mid = start + (end - start) // 2

    # First Half
    if key < arr[mid]:
      end = mid - 1
    # Second Half
    elif key > arr[mid]:
      start = mid + 1
    # Else arr[mid] == key
    else:
      keyIndex = mid
      # If we want to find the last occurrence
      if findMax:
        start = mid + 1
      # If we want to find the first occurrence
      else:
        end = mid - 1
        
  return keyIndex




def main():
  print(find_range([4, 6, 6, 6, 9], 6))
  print(find_range([1, 3, 8, 10, 15], 10))
  print(find_range([1, 3, 8, 10, 15], 12))


main()

Search in a Sorted Infinite Array

Before running a normal binary search, we need to determine the start and end bounds. We begin with start = 0 and end = 1, then keep moving the window forward and doubling its size until key <= arr[end]. Then we perform a binary search within these bounds.

import math


class ArrayReader:

  def __init__(self, arr):
    self.arr = arr

  def get(self, index):
    if index >= len(self.arr):
      return math.inf
    return self.arr[index]


def search_in_infinite_array(reader, key):
  
  # First, find the range to search in:
  # start with a window of size 2 and keep doubling its size
  start, end = 0, 1
  while key > reader.get(end):
    newStart = end + 1
    end += (end - start + 1) * 2
    start = newStart

  return binary_search(reader, key, start, end)


def binary_search(reader, key, start, end):

  while start <= end:
    mid = start + (end - start) // 2
    if key < reader.get(mid):
      end = mid - 1
    elif key > reader.get(mid):
      start = mid + 1
    else: # reader.get(mid) == key
      return mid

  return -1


def main():
  reader = ArrayReader([4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30])
  print(search_in_infinite_array(reader, 16))
  print(search_in_infinite_array(reader, 11))
  reader = ArrayReader([1, 3, 8, 10, 15])
  print(search_in_infinite_array(reader, 15))
  print(search_in_infinite_array(reader, 200))


main()
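
To make the window growth concrete, the sketch below traces each (start, end) pair tried before the binary search begins. It uses the same update rule as search_in_infinite_array. The helper probe_windows and the get lambda are made up for this example, with get standing in for ArrayReader.get.

```python
import math


def probe_windows(get, key):
    """Yield each (start, end) window tried before the binary search starts."""
    start, end = 0, 1
    yield start, end
    while key > get(end):
        new_start = end + 1
        end += (end - start + 1) * 2  # next window is twice as large
        start = new_start
        yield start, end


data = [4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
# Stand-in for ArrayReader.get: out-of-range reads return infinity
get = lambda i: data[i] if i < len(data) else math.inf

print(list(probe_windows(get, 16)))  # [(0, 1), (2, 5), (6, 13)]
```

The windows have sizes 2, 4, and 8, so the bounds are found in O(log n) probes, where n is the index of the key.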

Minimum Difference Element

Pretty similar to ceiling and floor element.

First, check whether key is smaller than the minimum or greater than the maximum; if so, return arr[0] or arr[-1] respectively. Otherwise, search for key. If found, return it. Else, the loop breaks with arr[start] holding the next greater element and arr[end] holding the next smaller element. Compare each one's difference from key and return the element with the minimum difference.

def search_min_diff_element(arr, key):
  
  if key < arr[0]:
    return arr[0]
  if key > arr[-1]:
    return arr[-1]

  start, end = 0, len(arr) - 1

  while start <= end:

    mid = start + (end - start) // 2

    if key < arr[mid]:
      end = mid - 1
    elif key > arr[mid]:
      start = mid + 1
    else:
      return arr[mid]

  # Here means key is not in arr
  # Since start = end + 1, compare arr[start] and arr[end]
  # arr[start] is the closest larger number
  # arr[end] is the closest smaller number
  if (arr[start] - key) < (key - arr[end]):
    return arr[start]

  return arr[end]
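
This section has no test driver, so one way to check the function above is against a linear-scan reference. The sketch below is an O(n) oracle for testing only, not a replacement for the binary search; the name min_diff_reference is made up for this example.

```python
def min_diff_reference(arr, key):
    # Linear scan, meant only as a test oracle for the O(log n) version.
    # min() keeps the first of equally close elements, i.e. the smaller
    # one in an ascending array, matching the arr[end] tie-break above.
    return min(arr, key=lambda x: abs(x - key))


print(min_diff_reference([4, 6, 10], 7))          # 6
print(min_diff_reference([4, 6, 10], 4))          # 4
print(min_diff_reference([1, 3, 8, 10, 15], 12))  # 10
print(min_diff_reference([4, 6, 10], 17))         # 10
```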

Bitonic (Mountain) Array Maximum

'''
Pretty much a normal binary search,
except that instead of searching for a key, we narrow the range down to the max element.

1. If mid is in the decreasing portion, then set end = mid
2. If mid is in the increasing portion, then set start = mid + 1

When the loop ends, start == end, and both are the index of the max element
'''

def find_max_in_bitonic_array(arr):
  
  start, end = 0, len(arr) - 1

  while start < end:
    mid = start + (end - start) // 2
    # Max index is <= mid
    if arr[mid] > arr[mid + 1]:
      end = mid
    # Max index is > mid
    else:
      start = mid + 1

  # When We break from loop, start == end
  # Both are index of max element
  return arr[start]



def main():
  print(find_max_in_bitonic_array([1, 3, 8, 12, 4, 2]))
  print(find_max_in_bitonic_array([3, 8, 3, 1]))
  print(find_max_in_bitonic_array([1, 3, 8, 12]))
  print(find_max_in_bitonic_array([10, 9, 8]))


main()

Search Bitonic Array

Three steps:

  1. Find the max element as above. This splits the array into an ascending portion and a descending portion.
  2. Binary search the ascending portion.
  3. If not found in step 2, binary search the descending portion.

def search_bitonic_array(arr, key):
  
  keyIndex = -1
  # First, Find Max of Bitonic Array
  maxIndex = find_max(arr)

  # Search 0 to maxIndex in Ascending order
  keyIndex = binary_search(arr, 0, maxIndex, key)

  # If not found above, search maxIndex + 1 to end in Descending Order
  if keyIndex == -1:
    keyIndex = binary_search(arr, maxIndex + 1, len(arr) - 1, key)

  return keyIndex

# Returns the index of the max element of the bitonic array
def find_max(arr):
  start, end = 0, len(arr) - 1

  while start < end:
    mid = start + (end - start) // 2

    # Descending Order
    if arr[mid] > arr[mid + 1]:
      end = mid
    # Ascending Order
    else:
      start = mid + 1

  return start

def binary_search(arr, start, end, key):
  
  while start <= end:
    mid = start + (end - start) // 2

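    # If Ascending (compare the bounds of the current range, not arr[0],
    # otherwise the descending portion can be misread as ascending)
    if arr[start] <= arr[end]: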
      if arr[mid] > key: # First Half
        end = mid - 1
      elif arr[mid] < key: # Second Half
        start = mid + 1
      else: # arr[mid] == key
        return mid
    else: # Descending Order
      if arr[mid] < key: # First Half
        end = mid - 1
      elif arr[mid] > key: # Second Half
        start = mid + 1
      else: # arr[mid] == key
        return mid

  return -1




def main():
  print(search_bitonic_array([1, 3, 8, 4, 3], 4))
  print(search_bitonic_array([3, 8, 3, 1], 8))
  print(search_bitonic_array([1, 3, 8, 12], 12))
  print(search_bitonic_array([10, 9, 8], 10))


main()

Search in Rotated Sorted Array

Start off like normal. Calculate mid.

Now check whether start-->mid is ascending.

  • If so, check whether key is in this range. If it is, adjust end to mid - 1. If not, adjust start to mid + 1.
  • If not ascending, then we know mid-->end is ascending. Check whether key is in this range. If so, adjust start to mid + 1. If not, adjust end to mid - 1.

def search_rotated_array(arr, key):
  start, end = 0, len(arr) - 1
  
  # Start off like Normal, calculate Mid
  while start <= end:
    mid = start + (end - start) // 2

    if key == arr[mid]:
      return mid

    # If arr[mid] >= arr[start], then we know start-->mid is 
    # sorted in Ascending order.
    if arr[mid] >= arr[start]:
      
      # Now we check if key is in this range.
      if key >= arr[start] and key < arr[mid]:

        # If in this range, adjust end
        end = mid - 1

      # Else, adjust start, since we know key is in second half
      else:
        start = mid + 1

    else: # Else, mid-->end in ascending order
      
      if key > arr[mid] and key <= arr[end]:
        start = mid + 1
      
      else:
        end = mid - 1

  return -1


def main():
  print(search_rotated_array([10, 15, 1, 3, 8], 15))
  print(search_rotated_array([4, 5, 7, 9, 10, -1, 2], 10))


main()

Search in Rotated Sorted Array (w/ Duplicates)

Same as above, except we have to handle the case when arr[start] == arr[mid] == arr[end], where we can't tell which half is sorted. In this case, we know arr[mid] != key since we already checked that, so arr[start] and arr[end] can't be key either. We can therefore move both bounds by 1: start += 1 and end -= 1.

This makes the worst-case time complexity O(n).

def search_rotated_with_duplicates(arr, key):
  start, end = 0, len(arr) - 1
  while start <= end:
    mid = start + (end - start) // 2
    if arr[mid] == key:
      return mid

    # the only difference from the previous solution,
    # if numbers at indexes start, mid, and end are same, we can't choose a side
    # the best we can do, is to skip one number from both ends as key != arr[mid]
    if arr[start] == arr[mid] and arr[end] == arr[mid]:
      start += 1
      end -= 1
    elif arr[start] <= arr[mid]:  # left side is sorted in ascending order
      if key >= arr[start] and key < arr[mid]:
        end = mid - 1
      else:  # key is not in the sorted left part
        start = mid + 1

    else:  # right side is sorted in ascending order
      if key > arr[mid] and key <= arr[end]:
        start = mid + 1
      else:
        end = mid - 1

  # we are not able to find the element in the given array
  return -1


def main():
  print(search_rotated_with_duplicates([3, 7, 3, 3, 3], 7))


main()
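
To illustrate the O(n) worst case, the sketch below reuses the same loop but returns the number of iterations instead of an index. The helper count_steps is made up for this example. When every element is equal and the key is absent, each iteration only trims one element from each end.

```python
def count_steps(arr, key):
    """Same loop as search_rotated_with_duplicates, but counts iterations."""
    start, end = 0, len(arr) - 1
    steps = 0
    while start <= end:
        steps += 1
        mid = start + (end - start) // 2
        if arr[mid] == key:
            break
        if arr[start] == arr[mid] and arr[end] == arr[mid]:
            start += 1   # can't pick a side, shrink by one on each end
            end -= 1
        elif arr[start] <= arr[mid]:  # left side is sorted
            if arr[start] <= key < arr[mid]:
                end = mid - 1
            else:
                start = mid + 1
        else:  # right side is sorted
            if arr[mid] < key <= arr[end]:
                start = mid + 1
            else:
                end = mid - 1
    return steps


for n in (8, 64, 512):
    print(n, count_steps([2] * n, 7))  # 8 4, then 64 32, then 512 256
```

The step count grows as n / 2 here, whereas an array of distinct values of the same size needs only about log2(n) steps.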

Count Rotations

'''
We want to find the pivot point: the index of the smallest element,
which equals the number of times the array has been rotated.
'''


def count_rotations(arr):
  
  # Start Binary Search as Normal
  start, end = 0, len(arr) - 1

  while start < end:
    mid = start + (end - start) // 2

    # Compare arr[mid] with numbers on both sides
    # Use mid < end and mid > start to avoid comparing beyond index
    # of current search range
    if mid < end and arr[mid] > arr[mid + 1]:
      return mid + 1
    elif mid > start and arr[mid] < arr[mid - 1]:
      return mid

    # Now, we need to see which side is sorted
    # Pivot point will be on the non-sorted side
    if arr[mid] >= arr[start]:
      start = mid + 1
    else:
      end = mid - 1

  # If we get to here, then no pivot point found
  return 0


def main():
  print(count_rotations([10, 15, 1, 3, 8]))
  print(count_rotations([4, 5, 7, 9, 10, -1, 2]))
  print(count_rotations([1, 3, 8, 10]))


main()
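
Since the rotation count equals the index of the smallest element, a linear scan gives a simple reference to test the binary search against. This is a sketch that assumes a non-empty array of distinct elements; the name count_rotations_reference is made up for this example.

```python
def count_rotations_reference(arr):
    # O(n) oracle: the pivot is the position of the minimum.
    # With duplicates this can be wrong, e.g. [3, 3, 7, 3] has its pivot
    # at index 3 but the first minimum is at index 0.
    return arr.index(min(arr))


print(count_rotations_reference([10, 15, 1, 3, 8]))        # 2
print(count_rotations_reference([4, 5, 7, 9, 10, -1, 2]))  # 5
print(count_rotations_reference([1, 3, 8, 10]))            # 0
```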

Count Rotations (w/ Duplicates)

Similar to above. But have to check the case when arr[start] == arr[mid] == arr[end], because we can't tell which half is sorted.

As before, we want to increment start by 1 and decrement end by 1. But before doing that, we must check whether start + 1 or end is the pivot point, so that we don't exclude it when we do start += 1 or end -= 1. Above, when just searching for key, we knew that arr[mid] != key, and since arr[start] == arr[mid] == arr[end], we could safely skip arr[start] and arr[end], because neither could contain key.

def count_rotations_with_duplicates(arr):
  start, end = 0, len(arr) - 1
  while start < end:
    mid = start + (end - start) // 2
    # if element at mid is greater than the next element
    if mid < end and arr[mid] > arr[mid + 1]:
      return mid + 1
    # if element at mid is smaller than the previous element
    if mid > start and arr[mid - 1] > arr[mid]:
      return mid

    # this is the only difference from the previous solution
    # if numbers at indices start, mid, and end are same, we can't choose a side
    # the best we can do is to skip one number from both ends if they are not the smallest number
    if arr[start] == arr[mid] and arr[end] == arr[mid]:
      if arr[start] > arr[start + 1]:  # element at start+1 is the smallest, so it is the pivot
        return start + 1
      start += 1
      if arr[end - 1] > arr[end]:  # element at end is the smallest, so it is the pivot
        return end
      end -= 1
    # left side is sorted, so the pivot is on right side
    elif arr[start] < arr[mid] or (arr[start] == arr[mid] and arr[mid] > arr[end]):
      start = mid + 1
    else:  # right side is sorted, so the pivot is on the left side
      end = mid - 1

  return 0  # the array has not been rotated


def main():
  print(count_rotations_with_duplicates([3, 3, 7, 3]))


main()