【学海拾遗】数据规模10^5,如何求一个数列中三角形的总数?

182 阅读4分钟

Offer 驾到,掘友接招!我正在参与2022春招打卡活动,点击查看活动详情

题目——Matrix53的数列 - Hard 的题解

Matrix53 和 Marvolo 发财啦,因为他们上次解决了 Medium 难度的数列问题,获得了“世纪难题”委员会的奖金。

“世纪难题”委员会惊讶于 Matrix53 和 Marvolo 的解题能力,对他们大加称赞。

“计算姬”委员会知道了他们的研究成果,想让他们对 Easy 难度的数列问题的时间复杂度做进一步的优化。

给定一正整数数列,统计出数列中的三角形的总数

数列中的三角形:若有三个数处于数列中三个不同的位置,且这三个数可以作为某一个三角形的三条边,则这三个数构成一个“数列中的三角形”

输入

第一行为数列的长度 n

第二行为 n 个正整数,代表数列中的元素 aia_i

输出

输出一个整数,表示这个数列中的三角形的总数

输入样例

4
1 2 2 2

输出样例

4

数据范围

3n105,3≤n≤10^5,
1ai1051≤a_i≤10^5

思路

本题的大意是在数列中选3个数,求能让3个数满足三角形性质的方案数.

首先,有如下假定:

max:数列所有数中的最大值max:数列所有数中的最大值
P(x):数列中x的个数,0xmaxP(x):数列中x的个数,0 \leq x\leq max

根据题意,max105max\leq 10^5. 由于所有项都是正数,显然P(0)=0P(0)=0ii可以为00,目的是方便接下来的计算.

P(i)P(i)可以很容易得出,输入的时候读到xx就令P(x)P(x)++就行.

设任意k[1,max]k\isin [1,max],我们要找到两个数列中的项,其长度为 iijj ,且满足iki\leq k, jkj\leq k.

如果我们要对所有的k寻找所有满足i+j>ki+j>k的方案数然后加起来,有如下公式(本题用不到此公式,大概看一下就行):

Ans=k=3max(i+j>ki<kj<ki/=j(P(i)P(j))P(k)+2i>ki<k(P(i)2)P(k)+(P(k)3)+...)Ans=\sum_{k=3}^{max}\Bigg(\sum_{\substack{i+j>k\\i<k\\j<k\\i\mathrlap{\,/}{=}j}}\bigg(P(i)*P(j)\bigg)*P(k)+\sum_{\substack{2i>k\\i<k}}\dbinom{P(i)}{2}*P(k)+\dbinom{P(k)}{3}+...\Bigg)

表达式十分复杂,因为iijj之和要大于kk,各自又要小于等于kk,还要排除很多的重复情况等等. 因此我们可以逆向思考,找到不能组成三角形的总方案数SS,再用总方案数RR减去SS就行了. 即对每个kk,算出i+jki+j\leqslant k的方案数. 这时i+jki+j\leqslant k就蕴含了i<ki<kj<kj<k,因此排除了重复的情况,式子更为简单. 先假定:

Q(x):数列中挑出两个数字之和为x的方案数,0xmaxQ(x):数列中挑出两个数字之和为x的方案数,0\leq x\leq max

那么易得,对于任意k,设i+j=k0i+j=k_0i+jki+j\leq k即意味着k0kk_0\leq k. 故把所有满足k0kk_0\leq kQ(k0)Q(k_0)加起来然后乘以P(k)P(k)就得到了此kk下无法构成三角形的方案数(别忘了kk一直是最大边),然后把所有的k的方案加起来即得到SS. 则

不能组成三角形的方案数S=k=3max(k0=2kQ(k0)P(k))(*)不能组成三角形的方案数S=\sum_{k=3}^{max}\Bigg(\sum_{k_0=2}^{k}Q(k_0)*P(k)\Bigg) \tag{*}
从数列中挑三个数的总方案数R=k=3max(k12)=max(max1)(max2)6从数列中挑三个数的总方案数R=\sum_{k=3}^{max}\dbinom{k-1}{2}=\frac{max*(max-1)*(max-2)}{6}
Ans=RSAns=R-S

在知道了Q(x)Q(x)的前提下,我们可以通过前缀和的方式来计算式子()(*),所需时间O(max)O(max).

最后就是Q(x)Q(x)怎么求了. 根据Q(x)Q(x)的定义,我们要找到所有满足i+j=k0i+j=k_0P(i)P(i)P(j)P(j),然后将它们相乘后全部相加. 不考虑重复情况时,设如下式子:

c(k0)=i=0k0(P(i)P(k0i))c(k_0)=\sum_{i=0}^{k_0}\bigg(P(i)*P(k_0-i) \bigg)

这时我们发现,上式不就是多项式乘法的样子吗!设多项式A(x)=k0=0maxP(k0)xk0A(x)=\sum_{k_0=0}^{max}P(k_0)x^{k_0}B(x)=k0=0maxc(k0)xk0B(x)=\sum_{k_0=0}^{max}c(k_0)x^{k_0},计算c(k0)c(k_0)即等价于计算B(x)=A(x)A(x)B(x)=A(x)*A(x). 这时就可以利用FFT简化复杂度了.

算出c(k0)c(k_0)后我们要去掉重复情况,然后才能得到Q(k0)Q(k_0). 我们发现由于是从0遍历到k0k_0,因此每个数的方案算了两次. 并且当k0k_0为偶数时有可能会i=k0i=k0/2i=k_0-i=k_0/2,这时不能简单的直接P(i)P(k0i)P(i)*P(k_0-i),而要计算组合数(P(k0/2)2)=P(k0/2)(P(k0/2)1)/2\dbinom{P(k_0/2)}{2}=P(k_0/2)*(P(k_0/2)-1)/2. 因此Q(k0)Q(k_0)应该如下:

Q(k0)={c(k0)2,k0为奇数c(k0)P(k0/2)22+(P(k0/2)2)=c(k0)P(k0/2)2,k0为偶数Q(k_0) = \begin{cases} \Large\frac{c(k_0)}{2}, &\text{$k_0$为奇数} \\\\ \Large\frac{c(k_0)-P(k_0/2)^2}{2}+\normalsize\dbinom{P(k_0/2)}{2}=\Large\frac{c(k_0)-P(k_0/2)}{2}, &\text{$k_0$为偶数} \end{cases}

至此所有的问题都已解决.

复杂度分析

读取输入并得到P(i)P(i)需要 O(n)O(n) ,通过P(i)P(i)FFT乘法计算得到c(i)c(i)需要O(maxlg(max))O(max*lg(max)),计算Q(i)Q(i)需要O(max)O(max),计算RR需要O(1)O(1),最后O(1)O(1)得到答案Ans. 最终的复杂度为O(maxlg(max))O(max*lg(max)).

代码

#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<cstring>
using namespace std;
using db=double;
using ll=long long;
const double pi=acos(-1);
const int maxn=100000+5;
template <class numType>struct Complex{
    numType a;
    numType b;
    Complex(numType x=0,numType y=0):a(x),b(y){}
    Complex<numType> operator*(Complex<numType> x){
        return {a*x.a-b*x.b,a*x.b+b*x.a};
    }
    Complex<numType> operator*(numType x){
        return {a*x,b*x};
    }
    Complex<numType> operator+(Complex<numType> x){
        return {a+x.a,b+x.b};
    }
    Complex<numType> operator-(Complex<numType> x){
        return {a-x.a,b-x.b};
    }
    Complex<numType> operator/(numType x){
        return {a/x,b/x};
    }
};
using Cp=Complex<double>;
int rev[maxn*3];
void initRev(int k){
//k为总位数,n=2时k=1
    rev[0]=0;
    int len=1<<k;
    for(int i=0;i<len;i++)
    rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
    //此处为递归算法
}
void FFT(Cp a[],int n,int f){
    //n必须为2的次方
    //必须提前算好rev数组
    for(int i=0;i<n;i++)
        if(i<rev[i])swap(a[i],a[rev[i]]);
    int len=n;
    for(n=2;n<=len;n*=2){
        Cp wn={cos(f*2*pi/n),sin(f*2*pi/n)};
        for(int i=0;i<len;i+=n){
            Cp w=1;
            for(int j=0;j<n/2;j++){
                Cp p=a[i+j];//p是原来的y[0],意为双数
                Cp q=w*a[i+j+n/2];//单数
                a[i+j+n/2]=p-q;
                a[i+j]=p+q;
                w=w*wn;
            }
        }
    }
    if(f==-1)for(int i=0;i<len;i++)a[i]=a[i]/(db)len;
}
ll P[maxn+5];
Cp arr[maxn*4];
ll Q[maxn+5];
int main(){
    int n;
    cin>>n;
    int mx=0;//对应上面分析的max
    for(int i=0;i<n;i++){
        int x;
        cin>>x;
        P[x]++;
        mx=max(x,mx);
    }
    //下面开始计算多项式乘法
    for(int i=1;i<=mx;i++){
        arr[i].a=P[i];
    }
    int len=1,lgn=0;
    while(len<mx*2)len*=2,lgn++;
    initRev(lgn);
    FFT(arr,len,1);
    for(int i=0;i<len;i++)
        arr[i]=arr[i]*arr[i];
    FFT(arr,len,-1);
    for(int i=0;i<=mx;i++){
        Q[i]=arr[i].a+0.5;
        //算出c(i)的同时把c(i)转化为Q(i)
        if(i%2==0)
            Q[i]=(Q[i]-P[i/2])/2;
        else 
            Q[i]/=2;
    }
    ll R=(ll)n*(n-1)*(n-2)/6;
    ll tmp=Q[1]+Q[2];
    ll sum=0;
    for(int i=3;i<=maxn;i++){
        tmp+=Q[i];
        //tmp是从Q(0)加到Q(i)
        sum+=(ll)P[i]*tmp;
    }
    cout<<R-sum;
}