巧用快排,秒杀 LeetCode 973. K Closest Points to Origin

1,740 阅读3分钟

细节决定成败。一个小细节可以让代码的性能大幅提升。

最近和朋友一起研究分治算法,看过《算法导论》后觉得,纸上得来终觉浅,绝知此事要躬行啊!遂去 LeetCode 上 Divide and Conquer 这个 Topic 下,做了这道题 973. K Closest Points to Origin

本文分享作者在做题时发现的优秀代码细节,希望和大家一起吸取营养并体会其中乐趣。

阅读前建议大家先自己做一下该题,至少思考一下,然后再往下阅读,这样更容易体会到这个细节的优雅。

一、题目

973. K Closest Points to Origin

We have a list of points on the plane.  Find the K closest points to the origin (0, 0).

(Here, the distance between two points on a plane is the Euclidean distance.)

You may return the answer in any order.  The answer is guaranteed to be unique (except for the order that it is in.)

Example 1:

Input: points = [[1,3],[-2,2]], K = 1
Output: [[-2,2]]
Explanation: 
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest K = 1 points from the origin, so the answer is just [[-2,2]].

Example 2:

Input: points = [[3,3],[5,-1],[-2,4]], K = 2
Output: [[3,3],[-2,4]]
(The answer [[-2,4],[3,3]] would also be accepted.)
 
Note:
1 <= K <= points.length <= 10000
-10000 < points[i][0] < 10000
-10000 < points[i][1] < 10000

题目大意就是让找出距离原点最近的前K个点。并且无需关心结果的顺序

二、代码

闲言少叙,直接上改进前的代码:


class Solution {
    public int[][] kClosest(int[][] points, int K) {
        quickSort(points, 0, points.length - 1);
        return Arrays.copyOfRange(points, 0, K);
    }
    
    private void quickSort(int[][] a, int low, int high) {
        if(low < high) {
            int pivot = partation(a, low, high);
            quickSort(a, low, pivot - 1);
            quickSort(a, pivot + 1, high);
        }
    }
    
    private int partation(int[][] a, int low, int high) {
        int[] pivot = a[low];
        int pivotDist = dist(pivot);
        
        while(low < high) {
            while(low < high && dist(a[high]) >= pivotDist) {
                high--;
            }
            a[low] = a[high];
            while(low < high && dist(a[low]) <= pivotDist) {
                low++;
            }
            a[high] = a[low];
        }
        a[low] = pivot;
        return low;
    }

    private int dist(int[] a) {
        return a[0] * a[0] + a[1] * a[1];
    }
}

用时16ms, 反复试了几次也都是10+ms,看到 LeetCode 记录用时最短的方案是 3ms,把这个3ms的方案提交一下,现在用时4ms,大概看了一下,他也是用快速排序,和我的代码几乎一样,但是我的为什么这么慢?

三、找差距,变优秀

差距主要在 quickSort 方法,他的写法类似下面,加了一些判断,去掉了一些不必要的排序,因为题目说 "You may return the answer in any order", 所以只需找到前K个最小的即可,无需保证前K个按照从小到大的顺序。


    private void quickSort(int[][] a, int low, int high, int K) {
        if(low < high) {
            int pivot = partation(a, low, high);
            if(pivot == K) return;
            if(pivot > K) {
                quickSort(a, low, pivot - 1, K);
            }else{
                quickSort(a, pivot + 1, high, K);
            }
        }
    }

改成这样后用时也是4ms. 我对这段代码的理解:
经过一趟快排后,枢轴(支点)的左侧都小于等于它,右侧都大于等于它。
(1)如果枢轴的下标等于K, 则说明枢轴左侧的K个点即是前K个距原点距离最小的点,返回即可;
(2)如果枢轴的下标大于K, 则要想满足(1),只需将 low 至 pivot - 1 继续进行快速排序;
(3)如果枢轴的下标小于K, 则要想满足(1),只需将 pivot + 1 至 high 继续进行快速排序。

由此可见思路上的差异,我之前的思路是利用快速排序进行从小到大排序,然后取前K个,而这种做法是找恰当的枢轴,枢轴左侧的点即是要找的点,省去很多无用功,将快速排序用的恰到好处。

四、总结

解决问题时,要注意题目中的条件,巧用适当的算法。找到与优秀的差距,变成优秀。