问题描述
小C、小U 和小R 三个好朋友喜欢做一些数字谜题。这次他们遇到一个问题,给定一个长度为n
的数组a
,他们想要找出符合特定条件的三元组 (i, j, k)
。具体来说,三元组要满足 0 <= i < j < k < n
,并且 max(a[i], a[j], a[k]) - min(a[i], a[j], a[k]) = 1
,也就是说,最大值与最小值之差必须为1。
他们决定请你帮忙编写一个程序,计算符合这个条件的三元组数量。
测试样例
样例1:
输入:
a = [2, 2, 3, 1]
输出:2
样例2:
输入:
a = [1, 3, 2, 2, 1]
输出:5
样例3:
输入:
a = [1, 3, 2, 2, 1, 2]
输出:12
思路解析
本题可以分解为找出数组中所有的三元组,其中最大的数与最小的数之差为 1。
给定一个长度为 n
的整数数组 a
,要求找到所有符合以下条件的三元组 (i, j, k)
:
也就是说,我们需要找到三元组 (i, j, k)
,它们的最大值与最小值之差为 1。可以通过枚举每一个三元组的方法来求解,但显然这样做的时间复杂度较高,最多为 ,不适合在大规模数据中执行。为了提高效率,我们需要进行优化。
问题关键
要满足题目条件,三元组的三个元素 a[i]
, a[j]
, a[k]
必须包含两个不同的数,且这两个数之差为 1。具体来说:
- 假设这两个不同的数是
x
和x+1
,那么三元组中的每一个元素必须是x
或x+1
。
所以,问题转化为:找出数组中含有数字 x
和数字 x+1
的所有可能三元组。
优化思路
-
统计数字的频率: 我们可以通过统计数组中每个数字出现的次数来减少不必要的计算。然后对于每对相邻的数字
x
和x+1
,我们可以从这些数字中选择三元组。 -
组合计数: 对于每对
x
和x+1
,我们可以计算出它们的出现次数,假设x
出现的次数为 ,而x+1
出现的次数为 ,那么可以通过组合选择的方式选择 3 个数。具体地,从这些数字中选择一个三元组时,可以从x
中选 1 个,x+1
中选 2 个,或者从x
中选 2 个,x+1
中选 1 个。组合数可以通过公式 计算,其中 表示从 个元素中选择 个元素。
计算三元组的方法
- 统计数组中每个数字的出现次数。
- 对于每对相邻数字
x
和x+1
,计算所有可能的三元组数量。 - 结果对 取模。
代码实现
#include <iostream>
#include <vector>
#include <unordered_map>
using namespace std;
const int MOD = 1e9 + 7;
// 计算组合 C(n, 3)
long long C3(long long n) {
if (n < 3) return 0;
return (n * (n - 1) * (n - 2)) / 6;
}
// 计算组合 C(n, 2)
long long C2(long long n) {
if (n < 2) return 0;
return (n * (n - 1)) / 2;
}
// 计算组合 C(n, 1)
long long C1(long long n) {
return n;
}
long long solution(vector<int>& a) {
unordered_map<int, long long> count;
// 统计每个数字的出现次数
for (int num : a) {
count[num]++;
}
long long result = 0;
// 遍历所有的数字,处理相邻的数字 x 和 x + 1
for (auto& entry : count) {
int x = entry.first;
long long cnt_x = entry.second;
if (count.find(x + 1) != count.end()) {
long long cnt_x_plus_1 = count[x + 1];
// 选择三元组的方式有两种:
// 1. 选择1个 x 和 2个 x+1
result = (result + C1(cnt_x) * C2(cnt_x_plus_1)) % MOD;
// 2. 选择2个 x 和 1个 x+1
result = (result + C2(cnt_x) * C1(cnt_x_plus_1)) % MOD;
}
}
return result;
}
int main() {
vector<int> a1 = {2, 2, 3, 1};
vector<int> a2 = {1, 3, 2, 2, 1};
vector<int> a3 = {1, 3, 2, 2, 1, 2};
cout << solution(a1) << endl; // 2
cout << solution(a2) << endl; // 5
cout << solution(a3) << endl; // 12
return 0;
}
代码解析
- 统计每个数字的频率: 我们使用了一个哈希表
count
来统计数组中每个数字的出现次数。遍历数组一次,时间复杂度为 。 - 计算组合数: 使用了函数
C3
、C2
和C1
分别计算三元组的组合数。C3
用于计算从 nn 个数字中选择 3 个的组合数,C2
用于选择 2 个,C1
用于选择 1 个。 - 遍历相邻数字: 对于每对相邻的数字 和 ,计算从这两个数字中选出三元组的所有可能方式,并累加到结果中。
- 模取运算: 由于可能的结果很大,我们每次计算时都对结果进行取模 ,以避免溢出。
复杂度分析
-
时间复杂度:
- 遍历数组统计频率需要 的时间。
- 遍历哈希表中的每一对相邻数字 和 需要 的时间,最坏情况下哈希表的大小是 。
- 所以总体时间复杂度是 。
-
空间复杂度:
- 需要一个哈希表来存储每个数字的出现次数,空间复杂度是 。