136. 合法三元组数量计算 题解 | 豆包MarsCode AI刷题

3 阅读4分钟

问题描述

小C、小U 和小R 三个好朋友喜欢做一些数字谜题。这次他们遇到一个问题,给定一个长度为n的数组a,他们想要找出符合特定条件的三元组 (i, j, k)。具体来说,三元组要满足 0 <= i < j < k < n,并且 max(a[i], a[j], a[k]) - min(a[i], a[j], a[k]) = 1,也就是说,最大值与最小值之差必须为1。
他们决定请你帮忙编写一个程序,计算符合这个条件的三元组数量。


测试样例

样例1:

输入:a = [2, 2, 3, 1]
输出:2

样例2:

输入:a = [1, 3, 2, 2, 1]
输出:5

样例3:

输入:a = [1, 3, 2, 2, 1, 2]
输出:12


思路解析

本题可以分解为找出数组中所有的三元组,其中最大的数与最小的数之差为 1。

给定一个长度为 n 的整数数组 a,要求找到所有符合以下条件的三元组 (i, j, k)

  • 0<=i<j<k<n0 <= i < j < k < n
  • max(a[i],a[j],a[k])min(a[i],a[j],a[k])=1max(a[i],a[j],a[k])−min(a[i],a[j],a[k])=1

也就是说,我们需要找到三元组 (i, j, k),它们的最大值与最小值之差为 1。可以通过枚举每一个三元组的方法来求解,但显然这样做的时间复杂度较高,最多为 O(n3)O(n^3),不适合在大规模数据中执行。为了提高效率,我们需要进行优化。

问题关键

要满足题目条件,三元组的三个元素 a[i], a[j], a[k] 必须包含两个不同的数,且这两个数之差为 1。具体来说:

  • 假设这两个不同的数是 xx+1,那么三元组中的每一个元素必须是 xx+1

所以,问题转化为:找出数组中含有数字 x 和数字 x+1 的所有可能三元组。

优化思路

  1. 统计数字的频率: 我们可以通过统计数组中每个数字出现的次数来减少不必要的计算。然后对于每对相邻的数字 xx+1,我们可以从这些数字中选择三元组。

  2. 组合计数: 对于每对 xx+1,我们可以计算出它们的出现次数,假设 x 出现的次数为 countx\text{count}_x,而 x+1 出现的次数为 countx+1\text{count}_{x+1},那么可以通过组合选择的方式选择 3 个数。具体地,从这些数字中选择一个三元组时,可以从 x 中选 1 个,x+1 中选 2 个,或者从 x 中选 2 个,x+1 中选 1 个。

    组合数可以通过公式 C(n,k)=n!k!(nk)!C(n, k) = \frac{n!}{k!(n-k)!} 计算,其中 C(n,k)C(n, k) 表示从 nn 个元素中选择 kk 个元素。

计算三元组的方法

  1. 统计数组中每个数字的出现次数。
  2. 对于每对相邻数字 xx+1,计算所有可能的三元组数量。
  3. 结果对 109+710^9 + 7 取模。

代码实现

#include <iostream>
#include <vector>
#include <unordered_map>
using namespace std;

const int MOD = 1e9 + 7;

// 计算组合 C(n, 3)
long long C3(long long n) {
    if (n < 3) return 0;
    return (n * (n - 1) * (n - 2)) / 6;
}

// 计算组合 C(n, 2)
long long C2(long long n) {
    if (n < 2) return 0;
    return (n * (n - 1)) / 2;
}

// 计算组合 C(n, 1)
long long C1(long long n) {
    return n;
}

long long solution(vector<int>& a) {
    unordered_map<int, long long> count;
    
    // 统计每个数字的出现次数
    for (int num : a) {
        count[num]++;
    }
    
    long long result = 0;

    // 遍历所有的数字,处理相邻的数字 x 和 x + 1
    for (auto& entry : count) {
        int x = entry.first;
        long long cnt_x = entry.second;
        if (count.find(x + 1) != count.end()) {
            long long cnt_x_plus_1 = count[x + 1];
            
            // 选择三元组的方式有两种:
            // 1. 选择1个 x 和 2个 x+1
            result = (result + C1(cnt_x) * C2(cnt_x_plus_1)) % MOD;
            // 2. 选择2个 x 和 1个 x+1
            result = (result + C2(cnt_x) * C1(cnt_x_plus_1)) % MOD;
        }
    }

    return result;
}

int main() {
    vector<int> a1 = {2, 2, 3, 1};
    vector<int> a2 = {1, 3, 2, 2, 1};
    vector<int> a3 = {1, 3, 2, 2, 1, 2};
    
    cout << solution(a1) << endl;  // 2
    cout << solution(a2) << endl;  // 5
    cout << solution(a3) << endl;  // 12

    return 0;
}

代码解析

  1. 统计每个数字的频率: 我们使用了一个哈希表 count 来统计数组中每个数字的出现次数。遍历数组一次,时间复杂度为 O(n)O(n)
  2. 计算组合数: 使用了函数 C3C2C1 分别计算三元组的组合数。C3 用于计算从 nn 个数字中选择 3 个的组合数,C2 用于选择 2 个,C1 用于选择 1 个。
  3. 遍历相邻数字: 对于每对相邻的数字 xxx+1x+1,计算从这两个数字中选出三元组的所有可能方式,并累加到结果中。
  4. 模取运算: 由于可能的结果很大,我们每次计算时都对结果进行取模 109+710^9 + 7,以避免溢出。

复杂度分析

  • 时间复杂度

    • 遍历数组统计频率需要 O(n)O(n)的时间。
    • 遍历哈希表中的每一对相邻数字 xxx+1x+1 需要 O(n)O(n) 的时间,最坏情况下哈希表的大小是 O(n)O(n)
    • 所以总体时间复杂度是 O(n)O(n)
  • 空间复杂度

    • 需要一个哈希表来存储每个数字的出现次数,空间复杂度是 O(n)O(n)