【算法偶遇】某团面试题1

45 阅读7分钟

问题

在xhs上看到某团面试这么一个问题: 给定长度n的数组a,求1<=i<j<k<=n,且满足a_j>a_k>a_i的三元组(i,j,k)的数目。

但是帖子上没有给出n的范围,根据贴主的描述,应该是给出n^2* log(n)的算法就可以通过面试题目,然而超过n^2的复杂度并不是计数问题该有的复杂度,经过思考,我给出n* log(n)* log(n)的算法,足以通过n=1e5。

测试已放到洛谷上:www.luogu.com.cn/problem/U60…

分析

对于逆序对问题,常用树状数组,这里是三元组,所以考虑使用CDQ分治。

首先对数组中所有元素进行离散化,便于使用树状数组计数,此时数组中的元素在1到n间。

CDQ分治后,令左右区间都为降序排序:

情况1:i,j在左区间,k在右。首先将右区间分桶,比如右区间为4,3,3,2,那么1-2是第一个桶(从1开始是因为所有数组中的元素都在1-n),3-3是第二个,第三个桶是4到3,所以桶里恒为空,第四个桶4-4。用一个指针在左区间遍历j,一个指针在右区间遍历k,如果a[j]>a[k],将a[j]的原下标(从大到小排序前下标)a[k]所在的桶的下标组成pair并记录下来。之后根据pair的第一关键字也就是原下标从小到大排序,然后按照原顺序从左向右遍历左区间,每次在树状数组上更新的时候,向这个数所在的桶更新,要统计答案时统计答案。

例如:

1,3,5,4,2,最后一层合并时:5,3,1 | 4,2。首先分桶,1-2是桶1,3-4是桶2。起初左指针在5,右指针在4,5>4,记录(3,2),3是5的原下标,2是因为4在第二个桶,然后左指针来到3,3<4,所以右指针来到2,3>2,记录(2,1),然后左指针来到1,1<2,右指针向右,右指针出界了,左右指针遍历结束。

我们记录的要统计的对是(3,2)和(2,1),根据第一关键字排序后是:(2,1),(3,2)。

之后按照原顺序1,3,5遍历左区间,1所属桶是1,桶1增加1,到3的时候,因为3是第二个数,和(2,1)的第一维匹配了,所以统计答案,第二维度是x的话,这里要统计的是第1个桶到第1个桶的前缀和,第1个桶到第2个桶的前缀和,,,,第1个桶到第x个桶的前缀和,这些前缀和们的和,此时是第一个桶到第一个桶的前缀和,之后3属于第二个桶,桶2增加1,然后到5了,根据(3,2),要将第1到第1个桶的前缀和,以及第1到第2个桶的前缀和计入答案。之后5不在任何一个桶内,不改变桶的值。

用常见技巧求前缀和的前缀和:例如到第k个桶,设桶为数组b,那么b1+(b1+b2)+(b1+b2+b3)+...+(b1+..+bk)=(k+1)∑bi-∑i* bi,因此用树状数组维护桶元素b_i和i* b_i即可。

情况2:i在左区间,j,k在右区间,这种情况直接按照原顺序,从右到左遍历右区间的j处理就可以了,想法较为自然,不做更多提示。

测试数据生成

/*  
模板符合C++98标准  
upd:25.08.16  
*/  
//输入输出相关头文件  
#include<iostream>  
#include<cstdio>  
//数据类型和STL  
#include<cstring>  
#include<string>  
#include<cctype>  
#include<set>  
#include<map>  
#include<unordered_map>  
#include<queue>  
#include<vector>  
#include<bitset>  
#include<stack>  
#include<utility>  
//其他  
#include<cmath>  
#include<algorithm>  
#include<cstdlib>  
#include<iomanip>  
#include<climits>  
#include<limits>  
#include<sstream>  
using namespace std;  
typedef long long ll;  
typedef unsigned long ul;  
typedef unsigned long long ull;  
  
#define fastio ios::sync_with_stdio(0);cin.tie(0);  
#define fin freopen("D:/in.txt","r",stdin);  
#define fout freopen("D:/out.txt","w",stdout);  
  
const ll kases=10;  
  
int main()  
{  
    freopen("D:/in.txt","w",stdout);  
    srand(time(NULL));  
      
    cout<<kases<<"\n\n";  
      
    for(ll kase=1;kase<=kases;kase++) {  
          
        ll n=1e5;  
        cout<<n<<"\n";  
        for(ll i=1;i<=n;i++) cout<<rand()+1<<" ";          
        cout<<"\n\n";  
    }  
      
  
      
    return 0;  
}

暴力法对拍程序

/*  
模板符合C++98标准  
upd:25.08.16  
*/  
//输入输出相关头文件  
#include<iostream>  
#include<cstdio>  
//数据类型和STL  
#include<cstring>  
#include<string>  
#include<cctype>  
#include<set>  
#include<map>  
#include<unordered_map>  
#include<queue>  
#include<vector>  
#include<bitset>  
#include<stack>  
#include<utility>  
//其他  
#include<cmath>  
#include<algorithm>  
#include<cstdlib>  
#include<iomanip>  
#include<climits>  
#include<limits>  
#include<sstream>  
using namespace std;  
typedef long long ll;  
typedef unsigned long ul;  
typedef unsigned long long ull;  
  
#define fastio ios::sync_with_stdio(0);cin.tie(0);  
#define fin freopen("D:/in.txt","r",stdin);  
#define fout freopen("D:/out.txt","w",stdout);  
  
const ll maxn=2e5+5;  
ll a[maxn];  
ll n;  
  
int main()  
{  
    freopen("D:\\Users\\Cody\\Desktop\\程序实验\\MT1\\out_naive.txt","w",stdout);  
      
    fin;  
    fastio;  
    ll kases;  
    cin>>kases;  
    for(ll kase=1;kase<=kases;kase++) {  
        cin>>n;  
        for(ll i=1;i<=n;i++) cin>>a[i];  
          
        ll ans=0;  
        for(ll i=1;i<=n;i++) {  
            for(ll j=1;j<=n;j++) {  
                for(ll k=1;k<=n;k++)  {  
                    if(i<j && j<k && a[j]>a[k] && a[k]>a[i]) ans++;  
                }  
            }  
        }  
        cout<<ans<<"\n";          
    }      
    return 0;  
}

参考程序

/*  
模板符合C++98标准  
upd:25.08.16  
*/  
//输入输出相关头文件  
#include<iostream>  
#include<cstdio>  
//数据类型和STL  
#include<cstring>  
#include<string>  
#include<cctype>  
#include<set>  
#include<map>  
#include<queue>  
#include<vector>  
#include<bitset>  
#include<stack>  
#include<utility>  
//其他  
#include<cmath>  
#include<algorithm>  
#include<cstdlib>  
#include<iomanip>  
#include<climits>  
#include<limits>  
#include<sstream>  
using namespace std;  
typedef long long ll;  
  
#define fastio ios::sync_with_stdio(0);cin.tie(0);  
#define fin freopen("D:/in.txt","r",stdin);  
#define fout freopen("D:/out.txt","w",stdout);  
#define lowbit(x) (x&(-x))  
  
const ll maxn=2e5+5;  
  
ll n,ans,ans1,ans2;  
ll ori[maxn];  
vector<pair<ll,ll>> to_calc;      
vector<ll> bucket_bdy;    //bucket_boundary  
  
struct number {  
    ll id,v;  
    bool operator < (const number &rhs) const {  
        return v<rhs.v;  
    }  
}ord[maxn];   //CDQ分治中用到的从大到小ordered的数组  
  
struct BIT {  
    ll val[maxn+5];  
    ll siz;  
    vector<ll> touched;  
    void init(ll siz) {  
        for(ll i=0;i<=siz;i++) val[i]=0;  
        touched.clear();  
        this->siz=siz;  
    }  
    BIT(ll siz) {init(siz);}  
    BIT() {}  
      
    void update(ll x,ll d) {  
        if(x<=0return ;      
        while(x<=siz) {  
            if(val[x]==0) touched.emplace_back(x);  
            val[x]+=d;  
            x+=lowbit(x);  
        }  
    }  
      
    ll sum(ll x) {  
        ll res=0;  
        while(x>0) {  
            res+=val[x];  
            x-=lowbit(x);  
        }  
        return res;  
    }  
      
    void reset() {  
        for(ll idx:touched) val[idx]=0;  
        touched.clear();  
    }  
}tree1(maxn),tree1_1(maxn),tree2(maxn),tree3(maxn);  
  
void CDQ(ll L,ll R) {  
    if(L==R) return ;  
    ll mid=(L+R)>>1;  
    CDQ(L,mid);  
    CDQ(mid+1,R);  
    //统计答案  
    if(R>=L+2) {  
        //i,j在左,k在右  
        //子问题A  
        to_calc.clear();    //bucket_boundary  
        //to_calc.reserve(R-L+1);  
        bucket_bdy.clear();  
        //bucket_bdy.reserve(R-L+1);  
        for(ll k=mid+1;k<=R;k++) bucket_bdy.emplace_back(ord[k].v-1);  
        stable_sort(bucket_bdy.begin(),bucket_bdy.end());  
        ll j=L,k=mid+1;  
        while(j<=mid && k<=R) {  
            if(ord[j].v>ord[k].v) {  
                ll bid=upper_bound(bucket_bdy.begin(),bucket_bdy.end(),  
                                   ord[k].v-1)-bucket_bdy.begin()-1+1;  
                to_calc.emplace_back(ord[j].id,bid);  
                j++;  
            }else {  
                k++;  
            }  
        }  
        stable_sort(to_calc.begin(),to_calc.end());  
          
        //TODO:输出to_calc  
        /*puts("to_calc:");  
        for(auto x:to_calc) {  
            cout<<x.first<<" "<<x.second<<"\n";  
        }  
        puts("");*/  
          
        ll nxt_match=0,bucket_cnt=bucket_bdy.size();  
        tree1.reset();  
        tree1_1.reset();  
        for(ll i=L;i<=mid && nxt_match<to_calc.size();i++) {  
            if(i==to_calc[nxt_match].first) {  
                ll bid=to_calc[nxt_match].second;  
                ans1+=(bid+1)*tree1.sum(bid)-tree1_1.sum(bid);  
                nxt_match++;  
            }  
            ll x=ori[i];  
            ll buck_id=lower_bound(bucket_bdy.begin(),bucket_bdy.end(),x)-bucket_bdy.begin()+1;  
            if(buck_id>bucket_cnt) continue;  
            tree1.update(buck_id,1);  
            tree1_1.update(buck_id,buck_id);  
        }              
          
        //i在左,k,j在右  
        //子问题B  
        if(R>L+2) {  
            tree2.reset();  
            tree3.reset();  
            for(ll i=L;i<=mid;i++) tree2.update(ori[i],1);  
            for(ll j=R;j>=mid+1;j--)  {  
                ans2+=tree3.sum(ori[j]-1);      
                tree3.update(ori[j],tree2.sum(ori[j]-1));  
            }              
        }  
          
    }  
    //归并排序  
    ll idx=L,l=L,r=mid+1;  
    static number t[maxn];  
    while(l<=mid || r<=R) {  
        ll vfd_l=l,vfd_r=r;    //var for debug  
        if(l>mid) {  
            t[idx]=ord[r];  
            r++;  
        }  
        else if(r>R) {  
            t[idx]=ord[l];  
            l++;  
        }  
        else if(ord[l].v>ord[r].v) {  
            t[idx]=ord[l];  
            l++;  
        }else {  
            t[idx]=ord[r];  
            r++;  
        }  
        idx++;  
    }  
    for(ll i=L;i<=R;i++) ord[i]=t[i];  
    /*  
    printf("区间[%lld,%lld]的排序结果:\n",L,R);  
    for(ll i=L;i<=R;i++) cout<<ord[i].v<<" ";  
    puts("");  
    for(ll i=L;i<=R;i++) cout<<ord[i].id<<" ";  
    puts("");*/  
}  
  
int main()  
{  
    freopen("D:\\Users\\Cody\\Desktop\\程序实验\\MT1\\out_opti.txt","w",stdout);  
    fin;  
    fastio;  
      
    ll kases;cin>>kases;  
    for(ll kase=1;kase<=kases;kase++) {  
        cin>>n;  
        for(ll i=1;i<=n;i++) cin>>ori[i];  
          
        //离散化到ord数组  
        ll t[maxn]; //临时用于离散化的数组  
        for(ll i=1;i<=n;i++) {  
            t[i]=ori[i];  
        }  
        stable_sort(t+1,t+1+n);  
        ll t1=unique(t+1,t+1+n)-(t+1);  
        //printf("t1=%lld\n",t1);  
        for(ll i=1;i<=n;i++) {  
            ord[i].id=i;  
            ord[i].v=lower_bound(t+1,t+1+t1,ori[i])-t;  
        }  
        for(ll i=1;i<=n;i++) {  
            ori[i]=ord[i].v;  
        }  
          
        //TODO:离散化测试  
        //puts("离散化测试");  
        //for(ll i=1;i<=n;i++) cout<<ori[i]<<" ";  
          
        ans1=ans2=0;  
        CDQ(1,n);  
          
        //TODO:归并排序测试  
        //puts("归并排序测试");  
        //for(ll i=1;i<=n;i++) cout<<ord[i].v<<" ";  
          
        //cout<<ans1<<" "<<ans2<<"\n";  
        cout<<ans1+ans2<<"\n";          
    }  
      
      
      
    return 0;  
}