K个最小?别傻傻地一个个找了!😎 聊聊二分答案的降维打击(2040. 两个有序数组的第 K 小乘积)

69 阅读9分钟

K个最小?别傻傻地一个个找了!😎 聊聊二分答案的降维打击

Hey,各位奋斗在一线的码农兄弟姐妹们!我是你们的老朋友,一个热爱性能优化的代码“老兵”。今天想跟大家聊聊最近我在项目中遇到的一个棘手问题,以及我是如何从“头皮发麻”到“恍然大悟”的。故事的开始,还得从一个看似简单的需求说起。

我遇到了什么问题?🤔

我目前在一家音乐流媒体公司工作,我们最近在做一个非常酷的功能:“音乐品味黑洞分析”

这个功能的目标是,找出我们的推荐系统中最不成功的 k 次推荐,也就是用户和歌曲的“最不匹配”组合。我们的系统里有两个核心数组:

  1. userPreferenceScores:一个有序数组,记录了某个用户的品味偏好分值(可正可负)。
  2. songFeatureScores:一个有序数组,记录了歌曲库里所有歌曲的某个特征分值(同样可正可负)。

我们定义一对(用户,歌曲)的“不匹配度”为 userPreferenceScore * songFeatureScore。这个值越小(负得越多),说明越不匹配。现在,产品经理的需求来了:“请找出不匹配度排在第 k 位的值是多少”。

比如 k=1 就是找出最不匹配的那一对,k=1000 就是找出第1000不匹配的那一对。

我的第一反应:最小堆!这题我熟!😉

看到“第k小”这三个字,我的DNA动了!这不就是经典的 Top-K 问题嘛(2040. 两个有序数组的第 K 小乘积)?用一个最小堆(PriorityQueue)来解决简直是教科书式的操作。

我的思路是这样的(也就是你给出的第一个解法):

  1. userPreferenceScores 数组里的每个用户,都看作一条“待处理流水线”。
  2. 对于每个用户 u,他与所有歌曲的乘积中,最小的那个(取决于 u 的分数是正还是负)作为这条流水线的“头”。
  3. 把所有流水线的“头”都扔进一个最小堆里。
  4. 然后,循环 k 次:
    • 从堆顶弹出一个全局最小的乘积。
    • 再把这个乘积所在流水线的“下一个”元素扔进堆里。
  5. k 次弹出的,不就是答案嘛!

我啪啪啪地敲完了代码,逻辑清晰,自我感觉良好。在小数据集上跑得飞快,简直完美!

// 这是我最初引以为傲的代码
// ... (此处省略与你提供的第一份代码完全相同的最小堆实现) ...

然后,我把它部署到了预发环境,准备接受现实的检验。结果…… “Time Limit Exceeded” (超出时间限制) 的红色警告亮瞎了我的眼。

我懵了,为什么?问题出在哪?我看了看生产环境的数据规模:userPreferenceScoressongFeatureScores 的长度都是几万级别,而 k 的值,竟然能达到几百万!

我的算法复杂度是 O(N + k*logN),当 k 巨大无比时,k*logN 这个操作就成了性能的致命瓶颈。我们需要循环几百万次,每次还要维护堆,服务器当然不干了。我意识到,这条路走不通,我不能一个一个地去“数”到第k个。

恍然大悟的瞬间:换个脑子!🤯

就在我抓耳挠腮的时候,我突然想到了一个关键点:我需要的是第k小的那个值,而不是前k小的所有值。我是不是可以换个问法?

不要问:“第k小的值是多少?”

而是问:“我猜一个值 X,这个 X 是不是第k小的值?”

这怎么判断呢?很简单,我只要能快速算出“有多少个乘积小于等于 X”,记为 count

  • 如果 count < k,说明我猜的 X 太小了,真正的第k小的值比 X 大。
  • 如果 count >= k,说明我猜的 X 太大了或者正好,真正的第k小的值可能就是 X,或者比 X 更小。

这……这不就是 二分查找 吗?!我们不是在数组索引上二分,而是在答案的取值范围上二分!

这个思路的转变,简直是降维打击!

用“二分答案”搞定它!

1. 确定答案范围 分数的范围是 [-100000, 100000],所以乘积的范围大约在 [-10^10, 10^10]。这就是我们二分查找的 lowhigh

2. 核心:countLessEqual(mid) 函数 这是整个算法的灵魂。对于一个我们猜的中间值 mid,如何高效计算有多少对乘积 u*s <= mid

这里有个小坑!u 的符号会影响不等式:

  • u > 0: 我们要找 s <= mid / u
  • u < 0: 除以负数,不等号反向!我们要找 s >= mid / u
  • u = 0: 乘积是0。如果 mid >= 0,所有歌曲都满足;否则都不满足。

因为 songFeatureScores 数组是有序的,所以对于每个 u,我们可以在 songFeatureScores 里面再用一次二分查找,来快速找到满足条件的歌曲数量。

3. 最终代码 下面就是我重构后的代码,它像一位优雅的刺客,精准而高效:

class Solution {
    // 主函数,在答案范围[-10^10, 10^10]上进行二分
    public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
        long low = -10000000001L; // 略小于理论最小值
        long high = 10000000001L; // 略大于理论最大值
        long ans = -1;

        while (low <= high) {
            long mid = low + (high - low) / 2;
            // check函数:计算有多少乘积 <= mid
            long count = countLessEqual(nums1, nums2, mid);

            if (count >= k) {
                // mid 太大或刚刚好,尝试更小的值
                ans = mid;
                high = mid - 1;
            } else {
                // mid 太小了,需要增大
                low = mid + 1;
            }
        }
        return ans;
    }

    /**
     * 神奇的 check 函数:计算小于等于 val 的乘积有多少个
     * 这是整个算法的核心,它本身又利用了二分查找来加速计数
     */
    private long countLessEqual(int[] nums1, int[] nums2, long val) {
        long totalCount = 0;
        int m = nums2.length;

        for (int x : nums1) {
            if (x > 0) {
                // 👉 x > 0: 找 y <= val / x
                // 在 nums2 中二分查找满足条件的 y 的数量
                int l = 0, r = m - 1, boundary = -1;
                while (l <= r) {
                    int midIdx = l + (r - l) / 2;
                    if ((long) x * nums2[midIdx] <= val) {
                        boundary = midIdx;
                        l = midIdx + 1; // 尝试找更大的
                    } else {
                        r = midIdx - 1;
                    }
                }
                totalCount += (boundary + 1);
            } else if (x < 0) {
                // 👉 x < 0: 找 y >= val / x (注意不等号反向!)
                int l = 0, r = m - 1, boundary = m;
                while (l <= r) {
                    int midIdx = l + (r - l) / 2;
                    if ((long) x * nums2[midIdx] <= val) {
                        r = midIdx - 1; 
                    } else {
                        boundary = midIdx;
                        l = midIdx + 1; // 尝试找更小的
                    }
                }
                totalCount += (m - boundary);
            } else { // x == 0
                // 👉 x = 0: 乘积为0. 只要 val >= 0,所有组合都满足
                if (val >= 0) {
                    totalCount += m;
                }
            }
        }
        return totalCount;
    }
}

当我用这个新方案再次提交时,Accepted!秒过! 😂 那种感觉,就像给一辆拖拉机换上了喷气式引擎!

举一反三:这种思维还能用在哪?

这种“二分答案”的思想非常强大,是解决“求第k个/最大化最小值/最小化最大值”这类问题的神器。只要你发现一个问题的答案具有单调性(即,如果 X 满足条件,那么所有 >X 的值也满足,或者反之),就可以尝试用它。

比如下面这些经典场景:

  1. 最大化最小值 (Aggressive Cows / 安置路灯): 要在一条街上放 k 盏灯,让相邻灯之间的最小距离尽可能大。我们可以二分这个“最小距离”,然后检查在这个距离下能否放下 k 盏灯。
  2. 最小化最大值 (Split Array Largest Sum): 把一个数组分成 m 个子数组,让这些子数组和的最大值尽可能小。我们可以二分这个“子数组和的最大值”,然后检查在这个限制下,最少需要把原数组分成几段。
  3. 送货问题:快递员要在一天内送完所有包裹,求最慢的速度是多少。我们可以二分“速度”,然后检查在这个速度下,是否能在规定时间内送完。

趁热打铁:更多练手好题!💪

说得再多,不如亲手敲一遍来得实在。为了让大家彻底掌握这个强大的思想,我特地从力扣(LeetCode)上精选了几道异曲同工之妙的经典题目。刷完它们,你绝对能把“二分答案”刻进自己的DNA里!

  1. LeetCode 378. 有序矩阵中第 K 小的元素

    • 题目简介:给你一个 n x n 矩阵,其中每行和每列元素均按升序排序,找到矩阵中第 k 小的元素。
    • 解题思路:和我们今天的问题一样,不要去用堆一个个找!直接对元素的值进行二分。check(mid) 函数就是去计算矩阵里有多少元素小于等于 mid。利用矩阵的行列有序性,可以从左下角或右上角出发,在 O(n) 时间内完成计数。是不是很酷!
  2. LeetCode 719. 找出第 K 小的数对距离

    • 题目简介:给定一个整数数组,返回所有数对之间距离的第 k 小的值。距离定义为 |nums[i] - nums[j]|
    • 解题思路:对数对的距离进行二分。check(mid) 函数就是计算有多少个数对的距离小于等于 mid。这需要先对原数组排序,然后用双指针或对每个元素再进行一次二分来高效计数。
  3. LeetCode 410. 分割数组的最大值

    • 题目简介:这是“最小化最大值”问题的完美典范!将一个非负整数数组分割成 m 个非空的连续子数组,要求使得这 m 个子数组各自和的最大值最小。
    • 解题思路:对“子数组和的最大值”进行二分。check(mid) 函数是判断:能否将数组分割成不超过 m 段,且每段的和都不超过 mid?这可以用贪心算法在 O(n) 时间内解决。
  4. LeetCode 1552. 两球之间的磁力

    • 题目简介:这是“最大化最小值”问题的完美典范!在数轴上有一些篮子,你要放 m 个球到篮子里,要求任意两个球之间的最小距离尽可能大。
    • 解题思路:对“最小距离”进行二分。check(mid) 函数是判断:能否放下 m 个球,使得它们之间的距离都至少为 mid?同样可以用贪心算法在 O(n) 时间内解决。

结语 🙏

从一个直观但会超时的堆方法,到最终高效的二分答案法,这次经历让我再次深刻体会到:

当问题规模变得巨大时,改变思考问题的角度,往往比优化代码细节更重要。

希望我的这次“踩坑”和“顿悟”经历能对你有所启发。下次再遇到 "Top-K" 或 "Max-Min" 这类问题时,除了堆,不妨也问问自己:“我能猜一下答案吗?” 😉

Happy Coding