本文已参与[新人创作礼]活动,一起开启掘金创作之路
题目描述
这是4月16日代码源div2的每日一题。
喵喵序列 - 题目 - Daimayuan Online Judge
知识点:线段树
题目描述
给定一个含有 n 个整数的序列 a1,a2,…an,三个数 i,j,k 是可爱的当且仅当 i<j<k 且 ai<a。
请你求出有多少组 i,j,k 是可爱的。
输入格式
第 1 行一个整数 n 表示序列元素个数。
第 2 行 n 个整数分别表示 a1,a2,…an。
输出格式
一行一个整数,表示所求数量。
样例输入
5
1 2 2 3 4
样例输出
7
样例说明
满足条件的有:(1,2,3),(1,2,4),(1,2,3),(1,2,4),(1,3,4),(2,3,4)(2,3,4),(2,3,4),共 7 个。
数据范围
对于全部数据,有 1≤n≤3×10^4,0≤ai<2^63。
\
问题解析
这题我用的是线段树求逆序对的解法。
这种方法是求二元组的,即左边有多少个小于它的数(具体可以看CodeForces: A. Inversions - 掘金 (juejin.cn))。但这题转成了三元组,所以要稍稍的加一点改动。
我们可以再准备一个线段树f2,是基于f1计算的(f1是计算二元组的线段树,即上面那一题的数)。我们从左往右遍历数组,f1记录左边有多少个数大于我们当前遍历到的数,这样就是一个二元组了。三元组就是基于二元组上再加一个数,所以f2就是记录f1数组中区间1~a[i-1]的区间和,这个区间和就说明我们当前的元素可以和前面的数组合出多少三元组。每次遍历完一个数后,在f1上对应的位置加上1(因为有可能有重复的元素出现)。
文字描述可能太过抽象,我们举个例子来写:1 2 2 3 4。
-
线段树f1是0 0 0 0 0,f2是0 0 0 0 0,然后我们遍历到第一个元素1,看f2中1到a[i]-1的区间和,为0(其实可以特判一下1就直接不用算区间和了),然后算f1的区间和,这里是看左边有多少个数小于1,当然也是0,然后我们在f1中把1的位置加上1。此时记录的三元组为0。
-
线段树f1是1 0 0 0 0,f2是0 0 0 0 0,然后我们遍历到第二个元素2,看f2中1到a[i]-1的区间和,为0,然后算f1的区间和,是1,所以f2上对应的位置+1,然后我们在f1中把2的位置加上1。此时记录的三元组为0。
-
线段树f1是1 1 0 0 0,f2是0 1 0 0 0,然后我们遍历到第二个元素2,看f2中1到a[i]-1的区间和,为0,然后算f1的区间和,是1,所以f2上对应的位置+1,然后我们在f1中把2的位置加上1。此时记录的三元组为0。
-
线段树f1是1 2 0 0 0,f2是0 2 0 0 0,然后我们遍历到第三个元素3,看f2中1到a[i]-1的区间和,为2,然后算f1的区间和,是3,所以f2上对应的位置+3,然后我们在f1中把3的位置加上1。此时记录的三元组为2。
-
线段树f1是1 2 1 0 0,f2是0 2 3 0 0,然后我们遍历到第二个元素2,看f2中1到a[i]-1的区间和,为5,然后算f1的区间和,是4,所以f2上对应的位置+4,然后我们在f1中把4的位置加上1。此时记录的三元组为5+2=7。
-
遍历结束,三元组为7,如果还有下一个元素,且是大于4的元素,那三元组会再加上f2的区间和:5,即三元组为12。
这个方法不止可以算二元三元组可以一直套下去,四元五元十元都可以,只要多加几棵树就行(超不超时就不知道了)
还有一点就是,这题没有说这里的数是在1~n的,所以我们要把数据离散化一下,变成紧凑的序列。
AC代码
#include<iostream>
using namespace std;
#include<vector>
#include<algorithm>
#include<math.h>
#include<set>
#include<numeric>
#include<string>
#include<string.h>
#include<iterator>
#include<map>
#include<unordered_map>
#include<stack>
#include<list>
#include<queue>
#include<iomanip>
#define endl '\n';
typedef long long ll;
typedef pair<int, int> PII;
const int N = 300050;
unordered_map<ll, ll>mymap;
void revise(ll k, ll l, ll r, ll x, vector<ll>& f, ll v)
{
if (l == r)
{
f[k] += v;
return;
}
int m = (l + r) / 2;
if (x <= m)revise(k + k, l, m, x, f, v);
else
revise(k + k + 1, m + 1, r, x, f, v);
f[k] = (f[k + k] + f[k + k + 1]);
}
ll calc(ll k, ll l, ll r, ll x, ll y, vector<ll>& f)
{
if (l == x && y == r)
{
return f[k];
}
int m = (l + r) / 2;
if (y <= m)return calc(k + k, l, m, x, y, f);
else
if (x > m)return calc(k + k + 1, m + 1, r, x, y, f);
else
{
return (calc(k + k, l, m, x, m, f) + calc(k + k + 1, m + 1, r, m + 1, y, f));
}
}
int main()
{
ll n, ans = 1, res = 0;
cin >> n;
vector<ll>a(n + 1), b, f1(4 * n), f2(4 * n);
for (int i = 1; i <= n; i++)cin >> a[i];
b = a;
sort(b.begin(), b.end());
mymap[b[1]] = ans++;
//离散化操作
for (int i = 2; i <= n; i++)
{
if (b[i] != b[i - 1])mymap[b[i]] = ans++;
}
for (int i = 1; i <= n; i++)
{
if (mymap[a[i]] != 1)res += calc(1, 1, n, 1, mymap[a[i]] - 1, f2);
if (mymap[a[i]] != 1)
{
ll ans = calc(1, 1, n, 1, mymap[a[i]] - 1, f1);
revise(1, 1, n, mymap[a[i]], f2, ans);
}
revise(1, 1, n, mymap[a[i]], f1, 1);
}
cout << res << endl;
return 0;
}