二分查找 - acgtyrant/Algorithm-and-Data-Structure GitHub Wiki

从二分查找谈起

重点是(误):

90% 以上的程序员无法正确无误的写出二分查找代码。

先前,我花了一小时多速码出二分查找的 C++ 实现,一次性测试通过,我击败了 90% 以上的程序员!

首先要看看 STL 中 binary_search 的接口是什么样子的:

template< class ForwardIt, class T >
bool binary_search( ForwardIt first, ForwardIt last, const T& value );

std::find 不同,返回值类型只是 bool, 但形参就很棘手了,它们都是迭代器!而且根据 Reference 所描述,该函数只搜索 [first, last)! 也就是左闭右开区间。

于是一开始我理所当然地对 last 减一, 以便在闭区间查找,贯彻迭代思想:

bool BinarySearch(
    const std::vector<int>::iterator first,
    const std::vector<int>::iterator last,
    const int value) {
  auto low = first;
  auto high = last - 1;  // Note that range is [first, last)
  auto mid = low + (high - low) / 2;
  while (*mid != value) {
    if (low == high) return false;
    if (*mid < value) {
      low = mid + 1;
    } else {
      high = mid;
    }
    mid = low + (high - low) / 2;
  }
  return true;
}

不过后来再写一遍时,我暗暗发觉,光用安全的迭代器还不够,还要搞明白:STL 为什么会有尾后迭代器?以及大多算法接收的迭代器范围是左闭右开的?

花了一小时半琢磨并动手重构二分查找实现,终于恍然大悟:左闭右开区间可以做到真正的对称

先单刀直入地判断 *middle 是否等于 value, 若等于,返回;若小于,那么就说明目标应该在 [low, middle) 之间;同理,大于时,那么比起说它在 (middle, high) 区间,倒不如说它是[middle + 1, high)! 再次,二分算法保证了区间始终向目标值逼近,且由于区间的两端始终是整型,于是要么直到收敛成的迭代器区间 [middle, high) 中的 *middle 刚好等于 value, 返回 true; 要么直到 high - low 等于一步步长,那么由于 (high - low) / 2 会返回 0, 于是 不光 high = middle 表达式里的左右值都等于 low, 而且还最终收敛成了空集 [middle, middle)(反观闭区间,就可没法得到空集), 返回 false. 于是左闭右开区间与二分查找配合得多么天衣无缝!

于是不再对 last 减一,重构其:

bool BinarySearch(
    const std::vector<int>::iterator first,
    const std::vector<int>::iterator last,
    const int value) {
  auto low = first;
  auto high = last;
  auto middle = first + (last - first) / 2;
  while (low != high) {
    if (*middle == value) return true;
    if (*middle > value) {
      high = middle;
    } else {
      low = middle + 1;
    }
    middle = first + (last - first) / 2;
  }
  return false;
}

不过还没完,虽然我在形参都用上 const 修饰,不过我发现 STL 算法 API 并没有这修饰,而且 const_iterator 迭代器更保证了随便你怎么折腾迭代器,爱指向哪就指向哪,但始终不能让你修改指向的值。所以为了进一步改善可读性,不再创建临时变量,直接把形参声明为 const_iterator 并拿来用:

bool BinarySearch(
    std::vector<int>::const_iterator first,
    std::vector<int>::const_iterator last,
    const int value) {
  std::vector<int>::const_iterator middle;
  while (first != last) {
    middle = first + (last - first) / 2;
    if (*middle == value) return true;
    if (*middle > value) last = middle;
    if (*middle < value) first = middle + 1;
  }
  return false;
}

超・完美无瑕之二分查找 C++ 实现!再来个贯彻递归思想的,很适合用在面试中,够一气呵成:

bool BinarySearch(
    const std::vector<int>::const_iterator first,
    const std::vector<int>::const_iterator last,
    const int value) {
  if (first == last) return false;
  auto middle = first + (last - first) / 2;
  if (*middle == value) return true;
  if (*middle > value) return BinarySearch(first, middle, value);
  if (*middle < value) return BinarySearch(middle + 1, last, value);
}

再次,我发现绝大部分所谓 C++ 二分查找实现,实则都是 C 实现,醉了……比如 Cee 给出的那个就是。你可以拿它与本项目中的 binary-search.ccbinary-search.h 比较,后者可是原汁原味的 C++ 实现,且压倒性地贯彻了 Google C++ Style Guide, 高下立判。

最后,正式回答 V2EX 那个问题「挖个坑,作为 python 程序员,面试时要求手写二分查找,可以说不么」:

其一,事实上这涉及了价值观上的考验,即可以看出将来你的上司要求你完成你并不喜欢的任务时,有没有服从安排的职业精神。当然,上司的安排是否合理则是另一回事了。

其二,从技术上来说,可以全面考察面试者:

  • 是否深刻理解了大多编程语言的「左闭右开」特征。
  • 有无运用递归、尾递归或迭代的思想。此外据我所知,Python 不支持尾调用优化。
  • 编码风格上的可读性、一致性和规范性表现如何。
  • 会不会用 first + (last - first) / 2 代替 (first + last) / 2.
  • 能否避免编程新手的常见错误。
  • 对 Python 特性运用得如何,其实我不太熟悉 Python, 但 @farseerfc 指出「python 有 slice 有 range 有 generator ,不需要C++迭代器這麼古老的語法……」;于是如果若改用 C++ 实现,要看他会不会用更高明的迭代器,const_iterator 以及其它特性等等。

吃我大模板啦!

本项目的目标之一是用 C++ 实现足以与 STL 媲美的数据结构与算法,为了泛型,自然避免不了要用模板。最近总算能小试牛刀了。

首先我犯了极其很常见的错误:把函数模板的声明与实现分离,分别放到头文件和 cc 文件里。解决也不难,直接把函数模板的定义放到头文件里去就行了。另外我借此反思了 C/C++ 编译模型,当时我有点困惑:大家都说声明与实现要分离,有助于信息隐蔽,即只要向客户提供头文件就可以了,不怕后者看到其具体实现,但如今模板的声明和定义必须放到同一个头文件里,不就违背了信息隐蔽了吗?bombless 对此解释得好,其实这真和信息隐蔽没有太大的关系,因为一样可以不用提供头文件,只要向客户提供二进制包以及文档就可以了。

此外,我通过用模板参数代替迭代器来重写 binary_search 实现时,注意后者在 STL 标准里的声明是:

template< class ForwardIt, class T >
bool binary_search( ForwardIt first, ForwardIt last, const T& value );

我有点意外,按我对 iterator tags 的理解,ForwardIt 可没法支持 auto middle = first + (last - first) / 2; 之类的表达式吧?为此我不光特地查了 std::binary_search 在 libstdc++ 上的具体实现,还跑去 stackoverflow 提问。Peter 的答案启发了我,原来还有 std::distancestd::advance 这么便利的好东西。此外我发现写模板库所要用到的技术果然和写应用程序的不一样,有五个怀疑是 GCC 扩展的函数,还有好多下划线……

于是还是用 RandomIterator 代替 ForwardIterator 算了,毕竟我目前还没有达到熟练阅读甚至编写模板库的阶段,也没必要。于是如今我删除了 binary_search.cc, 且 binary_search.h 的代码如下:

#ifndef BINARY_SEARCH_H_
#define BINARY_SEARCH_H_

#include <vector>

// Checks if an element equivalent to value appears within the range
// [first, last). There is std::binary_search exists, so do not use namespace
// std.

template <typename RandomrIterator, typename T>
bool BinarySearch(
    RandomrIterator first,
    RandomrIterator last,
    const T value) {
  if (first == last) return false;
  auto middle = first + (last - first) / 2;
  if (*middle == value) return true;
  if (*middle > value) return BinarySearch(first, middle, value);
  if (*middle < value) return BinarySearch(middle + 1, last, value);
}

namespace iteration {

template <typename RandomrIterator, typename T>
// first and last are not const.
bool BinarySearch(
    RandomrIterator first,
    RandomrIterator last,
    const T value) {
  RandomrIterator middle;
  while (first != last) {
    middle = first + (last - first) / 2;
    if (*middle == value) return true;
    if (*middle > value) last = middle;
    if (*middle < value) first = middle + 1;
  }
  return false;
}

}  // namespace iteration

#endif  // BINARY_SEARCH_H_

不出意外,对其它数据结构与算法也如法炮制,就不再一一说明了。

LeetCode 278

First Bad Version

这是我在 LeetCode 上的处女题,大概也是四年以来首次正经八百地刷在线算法题。本来想直接上我最喜欢的 C++, 但一看官方预先提供的数据结构,是一个C的传统数组,其索引值类型自然也是 int, 无奈只好改上C。希望将来掌握 Python 后再写出漂亮的代码。

此外,我还突然发现,与中规中矩地实现教科书上的数据结构与算法不同,它还蛮考验阅读能力的,我甚至运用上了批判性思维。

"Suppose you have n versions [1, 2, ..., n] and you want to find out the first bad one, which causes all the following ones to be bad." 有两个描述性假设:n 不小于 1; bad version 一定存在。所以一开始就得先判断 version 数量为一的极端情况,以及第一个版本是不是 bad. 其实这两种判断合并为一就行:if (isBadVersion(1)) return 1;. 且也再也不用担心 isBadVersion(middle - 1) 越界访问了。

”You should minimize the number of calls to the API.“ 说明了得用时间复杂度为 NlogN 的二分搜索,且不得用递归,得迭代。

我的C代码如下:

// Forward declaration of isBadVersion API.
bool isBadVersion(int version);

int firstBadVersion(int n) {
    if (isBadVersion(1)) return 1;
    int low = 1;
    int high = n + 1;
    int middle = low + (high - low) / 2;
    while (low != high) {
        if (isBadVersion(middle)) {
            if (isBadVersion(middle - 1)) {
                high = middle;
            } else {
                return middle;
            }
        } else {
            low = middle + 1;
        }
        middle = low + (high - low) / 2;
    }
}

第一次提交,一下 Accepted, 我击败了 80% LeetCode 用户!「LeetCode 处女提交 Accepted」成就达成,爽!关键还是用左闭右开区间逼近,且索引特别地从 1 开始。

不过其实还有更高明的解法,它不像我用上了那么多的判断分支,而是直接返回 low. 我琢磨了一会,发现在如此的迭代规律下,[low, high) 始终会收敛到第一个 bad version, 上一个索引值必然是 good version, 所以直接返回其索引值就可以了,于是这题真切地考查了高等数学中的区间逼近思想啊,妙!

TODO: 有待确认时间复杂度排名。

LeetCode 35

Search Insert Position

预先提供的 C++ 代码虽然还是令我抓狂的类成员函数,不过参数好歹是够现代的 vector 容器了,可以直接上迭代器,噢也!

其实思路与 ## LeetCode 278 一样,[low, high) 始终会收敛到不小于 target 的区间,于是 *low 便是第一个大于等于 target 的迭代器,所以直接返回 low - nums.begin() 即可。不得不说迭代器比传统数组索引值好用多了,nums.end() - nums.begin() 不光可以返回容器长度,first_iterator - second_iterator 更是一个有效的数组索引值。

class Solution {
public:
    int searchInsert(vector<int>& nums, int target) {
        auto low = nums.begin();
        auto high = nums.end();
        auto middle = low + (high - low) / 2;
        while (low < high) {
            if (*middle == target) return (middle - nums.begin());
            if (*middle > target) high = middle;
            if (*middle < target) low = middle + 1;
            middle = low + (high - low) / 2;
        }
        return (low - nums.begin());
    }
};

最后再吐槽下 LeetCode 这预先提供的代码:为什么非要实现为类成员函数?害我在 test 代码里大费周章地定义一个 Solution 对象,再调用其方法;此外,我发现 LeetCode 没有明确说明能不能用标准库,我只能不敢用;最后,这代码隐式地调用了 using namespace std, 这就与我的编程规范格格不入了。

LeetCode 34

Search for a Range

首先直接用二分查找找到第一个满足 *middle == targetmiddle, 否则直接返回 [-1, -1] 即可。接着分别用 [low, middle) 和 [middle + 1, high) 逼近一个所有值均等于 *middle 的最小区间和一个所有值大于 *middle 的最大区间。具体逼近法(迭代规律)直接参见代码:

class Solution {
public:
    vector<int> searchRange(vector<int>& nums, int target) {
        auto low = nums.begin();
        auto high = nums.end();
        auto middle = low + (high - low) / 2;
        while (low < high) {
            if (*middle == target) {
                auto temporary_high = middle;
                auto temporary_middle = low + (temporary_high - low) / 2;
                while (low < temporary_high) {
                    if (*temporary_middle >= target) {
                        temporary_high = temporary_middle;
                    } else {
                        low = temporary_middle + 1;
                    }
                    temporary_middle = low + (temporary_high - low) / 2;
                    
                }
                auto temporary_low = middle + 1;
                temporary_middle = temporary_low + (high - temporary_low) / 2;
                while (temporary_low < high) {
                    if (*temporary_middle <= target) {
                        temporary_low = temporary_middle + 1;
                    } else {
                        high = temporary_middle;
                    }
                    temporary_middle = temporary_low + (high - temporary_low) / 2;
                }
                return vector<int>{low - nums.begin(), high - nums.begin() - 1};
            }
            if (*middle < target) low = middle + 1;
            if (*middle > target) high = middle;
            middle = low + (high - low) / 2;
        }
        return vector<int>{-1, -1};
    }
};

别把区间错返回成 vector<int>{low - nums.begin(), high - nums.begin()}; 就好。

LeetCode 74

Search a 2D Matrix

一开始想出了一个失败的迭代法:

class Solution {
public:
    bool searchMatrix(vector<vector<int>>& matrix, int target) {
        auto row_low = matrix.begin();
        auto row_high = matrix.end();
        auto row_middle = row_low + (row_high - row_low) / 2;
        while (row_low < row_high) {
            if (row_middle->front() == target || row_middle->back() == target)
                return true;
            if (row_middle->back() < target) row_low = row_middle + 1;
            if (row_middle->front() > target) row_high = row_middle;
            row_middle = row_low + (row_high - row_low) / 2;
        }
        auto column_low = row_middle->begin();
        auto column_high = row_middle->end();
        auto column_middle = column_low + (column_high - column_low) / 2;
        while (column_low < column_high) {
            if (*column_middle == target) return true;
            if (*column_middle < target) column_low = column_middle + 1;
            if (*column_middle > target) column_high = column_middle;
            column_middle = column_low + (column_high - column_low) / 2;
        }
        return false;
    }
};

row_middle->front() < target < row_middle->back() 时,就迭代不动了。

斟酌再三,决定还是用以往的三个判断分支,且特别地,在 if (row_middle->front() < target) 时,直接 row_low = row_middle;, 按我的直觉,这是无害的。不过这么一来,row_lowrow_high 收敛到 middle 右边的区间不再是 [middle + 1, high), 而是 [middle, high) 了。于是在 row_lowrow_high 之间只剩一步步长的极端情况下,若仍旧 *row_middle < target 的话,这时已经没法再进一步收敛成空区间了,即 row_low 永远保持原位置不变,毕竟 (row_high - row_low) / 2 恒为 0.

所以 while 循环条件改为 row_low < row_high - 1 即可:

class Solution {
public:
    bool searchMatrix(vector<vector<int>>& matrix, int target) {
        if (matrix.size() == 0 || matrix.front().size() == 0) return false;
        if (matrix.front().front() > target || matrix.back().back() < target) return false;
        auto row_low = matrix.begin();
        auto row_high = matrix.end();
        auto row_middle = row_low + (row_high - row_low) / 2;
        while (row_low < row_high - 1) {
            if (row_middle->front() == target) return true;
            if (row_middle->front() < target) row_low = row_middle;
            if (row_middle->front() > target) row_high = row_middle;
            row_middle = row_low + (row_high - row_low) / 2;
        }
        auto column_low = row_low->begin();
        auto column_high = row_low->end();
        auto column_middle = column_low + (column_high - column_low) / 2;
        while (column_low < column_high) {
            if (*column_middle == target) return true;
            if (*column_middle < target) column_low = column_middle + 1;
            if (*column_middle > target) column_high = column_middle;
            column_middle = column_low + (column_high - column_low) / 2;
        }
        return false;
    }
};

显然,它会收敛到步长为一的 [row_low, row_high) 的区间上,且 target 值就在 row_low 中。于是接着便是顺理成章地在 row_low 上继续二分查找了。此外,可以在函数开头加入对二维长度之一为零,最左上角元素大于 target 或最右下角元素小于 target 极端情况的判断。

其实这题还有不少解法,毕竟没有明确要求时间复杂度。要么暴力穷举所有元素,时间复杂度为 n^2; 要么从从右上角元素开始穷举,时间复杂度 n; 先纵向二分查找,再横向二分查找,也就是我的解法,时间复杂度 2 * log_n; Push 所有元素进一维数组,接着一次性二分查找,时间复杂度 n + log_n. 我的解法击败了 40.73% C++ 实现。

⚠️ **GitHub.com Fallback** ⚠️