一道二分题引发的思考

342 阅读3分钟

二分思路简单,但是边界难抠。有时又容易进入死循环。所以针对这一问题,我们建议整理出一套模板,经过短暂时间内的调试或者直接AC可以提高面试时的认可度。

题目描述

给定一个按照升序排列的长度为n的整数数组,以及 q 个查询。
对于每个查询,返回一个元素k的起始位置和终止位置(位置从0开始计数)。
如果数组中不存在该元素,则返回“-1 -1”。

输入格式
第一行包含整数n和q,表示数组长度和询问个数。
第二行包含n个整数(均在1~10000范围内),表示完整数组。
接下来q行,每行包含一个整数k,表示一个询问元素。

输出格式
共q行,每行包含两个整数,表示所求元素的起始位置和终止位置。
如果数组中不存在该元素,则返回“-1 -1”。

数据范围

1≤n≤100000

1≤n≤100000

1q10000

1q10000

1≤k≤10000

1≤k≤10000

case:输入样例:

6 3

1 2 2 3 3 4

3

4

5

输出样例:

3 4

5 5

-1 -1

本题是练习二分很好的一道题目,二分程序虽然简单,但是如果写之前不考虑好想要查找的是什么,十有八九会是死循环或者查找错误,就算侥幸写对了也只是运气好而已。用二分去查找元素要求数组的有序性或者拥有类似于有序的性质。对本题而言,一个包含重复元素的有序序列,要求输出某元素出现的起始位置和终止位置,翻译一下就是:在数组中查找某元素,找不到就输出-1,找到了就输出不小于该元素的最小位置和不大于该元素的最大位置。所以,需要写两个二分,一个需要找到>=x的第一个数,另一个需要找到<=x的最后一个数。查找不小于x的第一个位置,较为简单:

int l = 0 ,r = len-1;
while (l<r){
    int mid = l+r>>1;
    if (arr[mid]<value){
        l = mid + 1;
    } else {
        r = mid;
    }
}

首先需要确定的是check条件的问题。首先我们需要确定左端第一个不小于value的元素。当a[mid]<x时,令l = mid + 1,mid及其左边的位置被排除了,可能出现解的位置是mid + 1及其后面的位置;当a[mid] >= x时,说明mid及其左边可能含有值为x的元素;当查找结束时,l与r相遇,l所在元素若是x则一定是x出现最小位置,因为l左边的元素必然都小于x。

查找不大于x的最后一个位置,便不容易了:

int l1 = l, r1 = n;
while (l1 + 1 < r1) {
    int mid = l1 + r1 >> 1;
    if (a[mid] <= x)  l1 = mid;
    else    r1 = mid;
}

要查找不大于x的最后一个位置,当a[mid] <= x时,待查找元素只可能在mid及其后面,所以l = mid; 当a[mid] > x时,待查找元素只会在mid左边,令r = mid。为什么不能令r=mid-1呢?因为如果按照上一个二分的写法,循环判断条件还是l < r,当只有两个元素比如2 2时,l指向第一个元素,r指向第二个元素,mid指向第一个元素,a[mid] <= x,l = mid还是指向第一个元素,指针不移动了,陷入死循环了,此刻l + 1 == r,未能退出循环。

那么直接把循环判断条件改成l + 1 < r呢?此时一旦只有两个元素,l和r差1,循环便不再执行,查找错误。 所以这里出现了二分的典型错误,l == r作为循环终止条件,会出现死循环,l + 1 == r作为循环终止条件,会出现查找错误。

问题如何解决,一种方法就是将查找的区间设置为左闭右开,比如待查找元素在[0,n - 1]范围内,可以写成[0,n),令r = n,这时候只有两个元素时,r是取最右边元素的后一个位置的,l和r相差2,还会执行循环。 现在再来理解上一段的r1 = mid,说明a[mid] > x时,r = mid就表示待查找元素会是在r的左边,因为r是开区间。上面这种写法修改了循环条件使得二分不会死循环,修改了区间的开闭性使得不会查找错误。另一种解决办法就是:

int l = 0, r = n - 1;
while (l < r)
 {
        int mid = l + r + 1 >> 1;
        if (a[mid] <= x) l = mid;
        else r = mid - 1;
 }

推荐这种解法,因为这种办法与求左边界时的变量设置一致,方便理解与记忆。 变化点在于check函数的改变以及mid取值的不同而已。

因为这时就算只有两个元素,l + 1 = r,mid = l,a[mid]小于x时l是会+1的,不小于x时r = mid也会缩小区间。 而查找不大于x的最后一个位置之所以会死循环是因为编程语言里面除以2的下取整性,试想下如果l + 1 = r时,mid = (l + r) / 2 = l,一旦a[mid] <= x,l = mid = l,区间并没有缩小,从而陷入死循环;如果一开始取mid为r,一旦a[mid] <= x,l = mid = r,区间缩小,否则r = mid - 1 = l区间缩小,l都会与r相遇,就不会陷入死循环了。 如何做到上取整呢?只需要取mid时在l + r后面再加1即可,这里l和r都是闭区间,所以当a[mid] > x时,r = mid - 1.

综上,给出我的AC代码

import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

/**
 * @author qule
 * @date 2021/4/28 6:45 下午
 * @since
 */
public class findPosition {
    /**
     * 1. 给定一个浮点数 n,求它的三次方根。
     * 2. 给定一个按照升序排列的长度为 n 的整数数组,以及 q 个查询。
     * 对于每个查询,返回一个元素 k 的起始位置和终止位置(位置从 0 开始计数)。
     * 如果数组中不存在该元素,则返回 -1 -1。
     * 6. 逆序对数量
     * @param args
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int query = sc.nextInt();

        int arr[] = new int[n];
        for (int i=0;i<n;i++){
            arr[i]=sc.nextInt();
        }
        for (int i = 0;i<query;i++){
            int value = sc.nextInt();
            List<Integer> result = binarySearchPosition(arr,value);
            System.out.println(result.get(0)+" "+result.get(1));
        }

        /**
         public static void main(String args[]) throws IOException{
         BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
         int n = Integer.parseInt(reader.readLine());
         String[] arrStr = reader.readLine().split(" ");
         int arr[]=new int[n];
         for (int i = 0; i < n; i++)
         arr[i] = Integer.parseInt(arrStr[i]);
         nixudui(arr,0,n-1);
         // 打印结果
         System.out.println(result);
         // 关闭输入流
         reader.close();
         }
         */
    }

    public static List<Integer> binarySearchPosition(int arr[],int value){
        List<Integer>result = new ArrayList<>();
        int len = arr.length;
        // 先找最左端的数据,再找最右边的
        int l = 0 ,r = len-1;
        while (l<r){
            int mid = l+r>>1;
            if (arr[mid]<value){
                l = mid + 1;
            } else {
                r = mid;
            }
        }
        if (arr[l]!=value){
            result.add(-1);
            result.add(-1);
            return result;
        }
        result.add(l);
        l = 0;
        r = len-1;
        while (l<r){
            int mid = l+r+1>>1;
            if (arr[mid]<=value){
                l = mid ;
            } else {
                r = mid-1;
            }
        }
        result.add(r);
        return result;
    }

}