一个模板解决二分搜索的7种变体

最近开始尝试参加 LeetCode 周赛,其中碰到一道题 5219. 每个小孩最多能分到多少糖果 需要使用二分法(可惜刷题太少了当时并没有看出这个套路)。众所周知,二分搜索法变种有很多,而写出一个正确没有 bug 的二分法还是有点难度的。又因为自己平常写二分法的时候总是需要在一些关键点上停下来思考一会儿,所有干脆花点时间整理一个涵盖所有变种的二分模板背下来,(脑)空间换时间!

直接上模板:

int l=<搜索空间起点>, r=<搜索空间终点>, t=<搜索目标的值>;
while (l < r) {
    int mi = <中点逻辑>;
    if (<符合搜索目标>) <符合条件抉择>;
    else <不符合条件抉择>;
}
<l的溢出判断以及目标边界判断>;

其中,尖括号包裹的内容针对不同变种有所不同。

当我们写一个二分法的时候,首先确定的是搜索目标以及目标值 t。而二分法的搜索目标可以总结为以下 7 种变体(术语按照个人喜好编的,这个不重要):

  1. 等值:最常见的,寻找目标值 t 的索引,不存在返回 -1;
  2. ceil:上界,寻找大于 t 的最小值索引;
  3. upper_ceil:若 t 存在,返回最大索引,若不存在,同 ceil;
  4. lower_ceil:若 t 存在,返回最小索引,若不存在,同 ceil。注意相当于求大于等于 t 的最小值索引;
  5. floor:下界,寻找小于 t 的最大值索引;
  6. lower_floor:若 t 存在,返回最小索引,若不存在同 floor;
  7. upper_floor:若 t 存在,返回最大索引,若不存在同 floor。

可以看到,除了等值情况,其实就是分为了上界,下界,以及分别与等值情况结合的情况。

接下来从最常见的两个变体讲起,ceil 和 floor,即上界和下界。

ceil

对于上界 ceil 变体,其搜索目标为 >t的最小值索引。(假设数组长度为 len,下同。)

其搜索空间为 [0, len],搜索空间的意思是算法执行结束后返回值的范围,其端点值也就是 l 和 r 的初始值。注意这里右空间为溢出值 len,代表了数组中所有元素都小于 t 这样一种情况,除此之外返回值必然落在数组索引范围里。

注意循环条件:

int l=<搜索空间起点>, r=<搜索空间终点>, t=<搜索目标的值>;
while (l < r) {...}

当循环退出时,l == r,所以终于不用思考用 l 还是用 r 啦!

至于循环体:

while (l < r) {
    int mi = <中点逻辑>;
    if (<符合搜索目标>) <符合条件抉择>;
    else <不符合条件抉择>;
}
<l的溢出判断以及目标边界判断>;

首先是**<中点逻辑>**,这里会有两个选择(也就是上取整和下取整),具体选择哪一个稍后解释,因为这涉及到二分法的一个天坑:

  1. int mi = l + (r - l) / 2
  2. int mi = l + (r - l + 1) / 2

然后是 **<符合搜索目标>**,直接抄写一下条件就好了,比如这里就是 arr[mi] > t

<符合条件抉择> 是啥呢?可以这样想,我们要求最小值,也就是满足条件还不够,还要尽可能的小,当然也不能排除当前的 mi。因此这里就是 r = mi,让关注的区间转向往小的区间。那 <不符合条件抉择> 反着来就好了,这里就是 l = mi+1,因为当前情况需要排除,所以 mi 需要加一。回到中点逻辑的选择,这里直接给出答案,需要使用第一种,稍后解释。

最后就是 <l的溢出判断以及目标边界判断> 了,其实这里比较简单,跳出时只有两种情况:

  1. l 溢出,也就是 l == len
  2. l 为符合条件的正确索引。

如果把溢出索引也当作正常结果,这里就可以不做任何处理,直接 return l 就好了。

组合上述抉择,最终的代码为:

int l=0, r=len, t=target;
while (l < r) {
    int mi = l + (r - l) / 2;
    if (arr[mi] > t) r = mi;
    else l = mi+1;
}
return l;

floor

有了上一节的基础,floor的分析就简单啦。先贴出代码:

int l=-1, r=len-1, t=target;
while (l < r) {
    int mi = l + (r - l + 1) / 2;
    if (arr[mi] < t) l = mi; // (1)
    else r = mi-1;
}
return l;

floor 的搜索目标是 < t的最大值。搜索空间变成了 [-1, len-1],其中 -1 表示数组所有元素都大于 t 的情况。相应的,我们找到一个满足条件的值后还不够,还需要确保这是最大值,所以 <符合条件抉择> 为:l = mi。响应的不符合的条件也要对照着修改一下。

注意中点逻辑也改了,这里选用了第二种,也就是上取整,这是为了解决一个死循环的 bug:

比如在算法的某一步,l == a, r == a+1,如果使用第一种中点逻辑,中点 mi == a,而如果算法走到 (1) 处 l = mi,也就是 l == a,于是区间经过调整后,并没有发生变化,算法陷入死循环。

把中点逻辑改为上取整就是为了解决这个问题。

lower_ceil 和 upper_floor

有了 lower 和 upper 的分析,这两个变种稍微修改一下就可以了。

对于 lower_ceil,只需要在搜索条件里加上等号即可。

int l=0, r=len, t=target;
while (l < r) {
    int mi = l + (r - l) / 2;
    if (arr[mi] >= t) r = mi;
    else l = mi+1;
}
return l;

upper_floor 同理:

int l=-1, r=len-1, t=target;
while (l < r) {
    int mi = l + (r - l + 1) / 2;
    if (arr[mi] <= t) l = mi;
    else r = mi-1;
}
return l;

等值

等值的话直接用市面上的二分模板就好了,不过为了省事,也可以直接用万能模板啦。

其实用两个的等值变种都可以,我一般用的是 lower_ceil,也就是大于等于 t 的最小值索引。

int l=0, r=len, t=target;
while (l < r) {
    int mi = l + (r - l) / 2;
    if (arr[mi] >= t) r = mi;
    else l = mi + 1;
}
return (l != len && arr[l] == t) ? l : -1;

溢出的情况说明没找到,而没溢出只能保证 l 不小于 t,需要判断是否真的相等。也就是非溢出且元素值等于 t 的情况下才说明找到了,否则返回 -1。

upper_ceil 和 lower_floor

这两个变体实际用到的不多。分别可以拿到upper和floor的结果后再进行一些额外判断即可。

比如对于 upper_ceil

int upper = upper(arr, t);
if (upper-1 >= 0 && arr[upper - 1] == t) {
    return upper - 1;
}
return upper;

也就是判断前一个元素是否等于 t。不是的话就返回upper结果。

lower_floor 也是类似的:

int floor = floor(arr, t);
if (floor+1 < arr.length && arr[floor+1] == t) {
    return floor+1;
}
return floor;

小练习

来做个练习题吧:给定一个有序数组 arr,请用 log(n) 的时间复杂度求出其中 t 的个数。

参考答案:

  1. 首先二分寻找 大于等于t的最小值索引 a,如果 a==len || arr[a] != t 说明 t 不存在,返回0。
  2. 否则二分寻找 大于t的最小值索引 b,返回 b-a。

参考代码:

int findNum(int [] arr, int target) {
    final int n = arr.length;
    int l = 0, r = n, t = target;
    while (l < r) {
        int mi = l + (r - l) / 2;
        if (arr[mi] >= t) r = mi;
        else l = mi + 1;
    }
    if (l==n || arr[l] != t) return 0;
    int left = l;
    l = 0; r = n;
    while (l < r) {
        int mi = l + (r - l) / 2;
        if (arr[mi] > t) r = mi;
        else l = mi + 1;
    }
    return l - left;
}