Java的Comparator在算法题中的使用

393 阅读2分钟

函数式接口

首先看Comparator接口的具体结构:

image.png

然后再看函数式接口的定义:

函数式接口(Functional Interface)就是一个有且仅有一个抽象方法,但是可以有多个非抽象方法的接口。

可以发现在Comparator接口中,除了int compare(T, T)方法为抽象方法没有具体的实现,其余的方法都给出了具体实现。所以Comparator是一个函数式接口,我们可以使用lambda表达式来让代码更简洁。

如何在算法题中使用Comparator

先从简单的数组开始,如果要对一个int类型的数组nums进行升序排序,我们都会写出下面的代码:

Arrays.sort(nums);

如果要求改了,需要对nums进行降序排序,那就可以用自定义的Comparator来实现了。对于int compare(T o1, T o2)方法的规定是:返回正数,则o1排在o2的后面;返回负数,则o1排在o2的前面;返回0,则o1o2相等。

int[] nums = new int[]{2, 4, 1, 5, 3};
Integer[] arr = new Integer[nums.length];

//int数组转换成Integer数组
for (int i = 0; i < nums.length; i++) {
    arr[i] = nums[i];
}

//实现了自定义的compare方法
Arrays.sort(arr, (a, b) -> {
    return b - a;
});

System.out.println(Arrays.toString(arr));//[5, 4, 3, 2, 1]

特别需要注意的是,int compare(T o1, T o2)比较的是两个引用类型,对于基本数据类型是无法进行比较的,所以要把int数组转换成Integer数组。

//直接对nums进行比较会导致编译不通过
int[] nums = new int[]{2, 4, 1, 5, 3};
Arrays.sort(nums, (a, b) -> { 
    return b - a;
});

再考虑一种极端情况:当前正在进行比较的两个数相减之后超过了Integer.MAX_VALUE或小于Integer.MIN_VALUE,由于int compare(T o1, T o2)规定返回int类型,会导致排序结果出现异常。

//数组中有元素非常小
int[] nums = new int[]{2, 4, Integer.MIN_VALUE, -5, 3};

Integer[] arr = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
    arr[i] = nums[i];
}
Arrays.sort(arr, (a, b) -> b - a);
System.out.println(Arrays.toString(arr));//错的排序结果:[3, -5, -2147483648, 4, 2]

可采取一种健壮性更好的写法来规避这种超过int大小限制的问题

int[] nums = new int[]{2, 4, Integer.MIN_VALUE, -5, 3};
Integer[] arr = new Integer[nums.length];
for (int i = 0; i < nums.length; i++) {
    arr[i] = nums[i];
}

//并不是简单粗暴地返回两数之差,而是判断大小后返回1,0,-1
Arrays.sort(arr, (a, b) -> {
    if (a > b) {
        return -1;
    } else if (a < b) {
        return 1;
    } else {
        return 0;
    }
});

System.out.println(Arrays.toString(arr));//对的排序结果:[4, 3, 2, -5, -2147483648]

leetcode 56:合并区间

class Solution {
    public int[][] merge(int[][] intervals) {
        //重写了compare方法,使得多个区间根据区间开始处的大小进行升序排序
        //需要注意,这里是对二维int数组的每一行进行排序
        //a、b是指向两个行的引用类型,所以不需要转换成Integer数组
        Arrays.sort(intervals, (a, b) -> {
            return a[0] - b[0];
        });
        int[][] res = new int[intervals.length][2];
        
        //遍历这些区间,根据当前区间的开始处的大小,选择合并或新增区间
        int index = -1;
        for (int[] interval : intervals) {
            if (index == -1 || interval[0] > res[index][1]) {
                res[++index] = interval;
            } else {
                res[index][1] = Math.max(res[index][1], interval[1]);
            }
        }
        
        return Arrays.copyOf(res, index + 1);
    }
}