合法三元组数量计算| 豆包MarsCode AI刷题

222 阅读5分钟

为了实现这个功能,我们需要找到满足条件的三元组 (i, j, k),使得:

  1. 0 <= i < j < k < n,即 ijk 是不同的,并且按顺序排列。
  2. max(a[i], a[j], a[k]) - min(a[i], a[j], a[k]) = 1,即三元组中的最大值与最小值之差为 1。

思路:

  1. 我们可以直接遍历数组中的所有可能的三元组 (i, j, k)
  2. 对每个三元组,我们计算它们的最大值和最小值,检查是否满足 max(a[i], a[j], a[k]) - min(a[i], a[j], a[k]) == 1
  3. 如果满足条件,则计数。

优化建议:

我们可以优化一下这个过程,避免重复计算和遍历所有三元组。通过记录数组中每个元素的数量,可以快速判断符合条件的三元组。

代码实现:

cpp
Copy code
#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

long long solution(vector<int> a) {
    long long count = 0;
    int n = a.size();

    // 遍历所有可能的三元组 (i, j, k)
    for (int i = 0; i < n - 2; ++i) {
        for (int j = i + 1; j < n - 1; ++j) {
            for (int k = j + 1; k < n; ++k) {
                int maxVal = max({a[i], a[j], a[k]});
                int minVal = min({a[i], a[j], a[k]});
                if (maxVal - minVal == 1) {
                    count++;
                }
            }
        }
    }
    
    return count;
}

int main() {
    vector<int> a1 = {2, 2, 3, 1};
    vector<int> a2 = {1, 3, 2, 2, 1};
    vector<int> a3 = {1, 3, 2, 2, 1, 2};
    
    cout << (solution(a1) == 2) << endl; // 输出1表示正确
    cout << (solution(a2) == 5) << endl; // 输出1表示正确
    cout << (solution(a3) == 12) << endl; // 输出1表示正确
    
    return 0;
}

解释:

  1. solution 函数通过三个嵌套循环遍历所有可能的三元组 (i, j, k),并计算每个三元组的最大值和最小值。
  2. 如果最大值与最小值之差为1,则计数器 count 增加。
  3. main 函数中,我们通过比较返回值与预期值,来验证程序是否正确。

复杂度:

这个实现的时间复杂度是 O(n^3),因为我们有三重嵌套循环来遍历所有的三元组。对于大数组,这种算法可能会较慢。

改进方向:

如果数组的大小很大,我们可以考虑其他更高效的算法,例如通过哈希表或计数排序来优化,避免直接遍历所有三元组。

要优化这个问题,目标是尽量减少不必要的计算,避免三重循环导致的 O(n^3) 时间复杂度。我们可以考虑以下优化方法:

1. 使用计数法

通过统计每个数值在数组中的出现频率,可以大大减少计算量。具体来说,可以利用以下观察:

  • 我们只关心三元组中的最大值和最小值之差为 1。
  • 因此,只需统计数组中出现的数字,并查看哪些数字相差为 1。例如,如果 x 出现了 count[x] 次,且 x+1 出现了 count[x+1] 次,那么可能的三元组 (i, j, k) 就是从这两个数字中挑选出合适的三个数字来。

2. 通过两个数字的计数,找到符合条件的三元组

对于每个数字 x,检查数组中数字 xx+1 的组合。具体来说:

  • 如果数组中 x 出现了 cnt_x 次,x+1 出现了 cnt_x_plus_1 次,那么所有符合条件的三元组 (i, j, k) 可以通过这两个数字来构成。

3. 具体实现

我们通过以下步骤来优化代码:

  • 统计每个数字的出现次数。
  • 对每对相邻的数字(xx+1),计算所有符合条件的三元组数目。

优化后的代码:

cpp
Copy code
#include <iostream>
#include <vector>
#include <unordered_map>

using namespace std;

long long solution(vector<int> a) {
    long long count = 0;
    unordered_map<int, long long> freq;  // 记录每个数字的出现次数
    
    int n = a.size();
    
    // 统计每个数字的出现次数
    for (int i = 0; i < n; ++i) {
        freq[a[i]]++;
    }
    
    // 遍历所有可能的数字对 (x, x+1)
    for (auto& entry : freq) {
        int x = entry.first;
        if (freq.find(x + 1) != freq.end()) {
            long long cnt_x = entry.second;  // 数字 x 的出现次数
            long long cnt_x_plus_1 = freq[x + 1];  // 数字 x+1 的出现次数
            
            // 计算所有符合条件的三元组数目
            // 选择一个 x, 一个 x+1, 然后选择其他两个数字
            count += cnt_x * cnt_x_plus_1 * (cnt_x + cnt_x_plus_1 - 2) / 2;
        }
    }
    
    return count;
}

int main() {
    vector<int> a1 = {2, 2, 3, 1};
    vector<int> a2 = {1, 3, 2, 2, 1};
    vector<int> a3 = {1, 3, 2, 2, 1, 2};
    
    cout << (solution(a1) == 2) << endl; // 输出1表示正确
    cout << (solution(a2) == 5) << endl; // 输出1表示正确
    cout << (solution(a3) == 12) << endl; // 输出1表示正确
    
    return 0;
}

代码解释:

  1. 统计数字出现次数:我们通过 unordered_map 来统计每个数字在数组中出现的次数。
  2. 遍历相邻的数字对:我们只关心 xx+1 这样的数字对。
  3. 计算符合条件的三元组数目:对于每对 xx+1,如果它们都存在,我们通过组合数学公式计算满足条件的三元组数目。对于 cnt_xcnt_x_plus_1 次出现的 xx+1,可以组成的三元组数目是: cnt_x×cnt_x_plus_1×(cnt_x+cnt_x_plus_1−2)/2\text{cnt_x} \times \text{cnt_x_plus_1} \times (\text{cnt_x} + \text{cnt_x_plus_1} - 2) / 2cnt_x×cnt_x_plus_1×(cnt_x+cnt_x_plus_1−2)/2 这个公式计算的是从两个数字中选出 3 个数的方法数。

复杂度:

  • 时间复杂度O(n),因为我们只需要遍历一遍数组来统计频率,然后遍历频率表中的每个数字对进行计算,最多有 O(n) 个不同的数字。
  • 空间复杂度O(n),用于存储数字的频率。

这个优化版本大大减少了计算量,特别适合于数组较大的情况。