Nim Counting_abc212_h分析与解答

69 阅读6分钟

前置知识

前置知识:FWT(快速沃尔什变换),FWT相关性质。

先来看看FWT用于解决什么问题:

给两个长度相同且长度为2的n次方的序列(n>=0),序列下标从0开始到2^n-1,两序列为A和B,定义序列C,C[k]=求和(A[i]* B[j]) ,k=i位运算j。位运算有:与,或,异或。

令序列长度m=2^n,则遍历A时遍历B,需要m^2的复杂度,使用FWT可以用mlog(m)求出这个C序列。

以下是推导过程:

1.jpg

分割线

2.jpg

分割线

3.jpg

分割线

4.jpg

可以用洛谷这道模板题验证FWT程序。www.luogu.com.cn/problem/P47…

/*  
C++17 Standard  
upd:25.09.15  
*/  
#include<bits/stdc++.h>  
using namespace std;  
typedef long long ll;  
typedef unsigned long ul;  
typedef unsigned long long ull;  
  
#define fastio ios::sync_with_stdio(0);cin.tie(0);  
#define fin freopen("D:/in.txt","r",stdin);  
#define fout freopen("D:/out.txt","w",stdout);  
  
const ll maxn=2e5+5,mod=998244353,inv2=499122177;  
ll a[maxn],b[maxn],c[maxn],fwta[maxn],fwtb[maxn],fwtc[maxn];  
ll n,m;  
  
void calc_fwt_or(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        fwt[L]=a[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    calc_fwt_or(a,fwt,L,mid);  
    calc_fwt_or(a,fwt,mid+1,R);  
    for(ll i=mid+1;i<=R;i++){  
        fwt[i]=(fwt[i]+fwt[i-(mid-L+1)])%mod;  
    }  
}  
  
void calc_seq_or(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        a[L]=fwt[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    for(ll i=mid+1;i<=R;i++){  
        fwt[i]=(fwt[i]-fwt[i-(mid-L+1)]+mod)%mod;  
    }  
    calc_seq_or(a,fwt,L,mid);  
    calc_seq_or(a,fwt,mid+1,R);  
}  
  
void calc_fwt_and(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        fwt[L]=a[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    calc_fwt_and(a,fwt,L,mid);  
    calc_fwt_and(a,fwt,mid+1,R);  
    for(ll i=L;i<=mid;i++){  
        fwt[i]=(fwt[i]+fwt[i+(mid-L+1)])%mod;  
    }  
}  
  
void calc_seq_and(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        a[L]=fwt[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    for(ll i=L;i<=mid;i++){  
        fwt[i]=(fwt[i]-fwt[i+(mid-L+1)]+mod)%mod;  
    }  
    calc_seq_and(a,fwt,L,mid);  
    calc_seq_and(a,fwt,mid+1,R);  
}  
  
void calc_fwt_xor(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        fwt[L]=a[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    calc_fwt_xor(a,fwt,L,mid);  
    calc_fwt_xor(a,fwt,mid+1,R);  
    static ll temp_fwt[maxn];  
    for(ll i=L;i<=R;i++) temp_fwt[i]=fwt[i];  
    for(ll i=L;i<=mid;i++){  
        fwt[i]=(temp_fwt[i]+temp_fwt[i+(mid-L+1)])%mod;  
    }  
    for(ll i=mid+1;i<=R;i++) {  
        fwt[i]=(-temp_fwt[i]+temp_fwt[i-(mid-L+1)]+mod)%mod;  
    }  
      
}  
  
void calc_seq_xor(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        a[L]=fwt[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    static ll temp_fwt[maxn];  
    for(ll i=L;i<=mid;i++) {  
        temp_fwt[i]=(fwt[i]+fwt[i+(mid-L+1)])%mod*inv2%mod;  
    }  
    for(ll i=mid+1;i<=R;i++) {  
        temp_fwt[i]=(fwt[i-(mid-L+1)]-fwt[i]+mod)%mod*inv2%mod;  
    }  
    for(ll i=L;i<=R;i++) fwt[i]=temp_fwt[i];  
    calc_seq_xor(a,fwt,L,mid);  
    calc_seq_xor(a,fwt,mid+1,R);      
}  
  
void fwt_or() {  
    //计算fwta,fwtb  
    calc_fwt_or(a,fwta,0,m-1);  
    calc_fwt_or(b,fwtb,0,m-1);  
    for(ll i=0;i<m;i++) fwtc[i]=fwta[i]*fwtb[i]%mod;  
    calc_seq_or(c,fwtc,0,m-1);  
}  
void fwt_and() {  
    //计算fwta,fwtb  
    calc_fwt_and(a,fwta,0,m-1);  
    calc_fwt_and(b,fwtb,0,m-1);  
    for(ll i=0;i<m;i++) fwtc[i]=fwta[i]*fwtb[i]%mod;  
    calc_seq_and(c,fwtc,0,m-1);  
}  
  
void fwt_xor() {  
    //计算fwta,fwtb  
    calc_fwt_xor(a,fwta,0,m-1);  
    calc_fwt_xor(b,fwtb,0,m-1);  
    for(ll i=0;i<m;i++) fwtc[i]=fwta[i]*fwtb[i]%mod;  
    calc_seq_xor(c,fwtc,0,m-1);  
}  
  
int main()  
{  
    cin>>n;  
    m=(1LL<<n);  
    for(ll i=0;i<m;i++) cin>>a[i];  
    for(ll i=0;i<m;i++) cin>>b[i];  
      
    fwt_or();  
      
    //TODO  
    /*  
    printf("ftwa:\n");  
    for(ll i=1;i<=m;i++) cout<<fwta[i]<<" ";  
    cout<<"\n";  
      
    printf("ftwb:\n");  
    for(ll i=1;i<=m;i++) cout<<fwtb[i]<<" ";  
    cout<<"\n";*/  
      
    for(ll i=0;i<m;i++) cout<<c[i]<<" ";  
    cout<<"\n";  
      
    fwt_and();  
    for(ll i=0;i<m;i++) cout<<c[i]<<" ";  
    cout<<"\n";  
      
    fwt_xor();  
    for(ll i=0;i<m;i++) cout<<c[i]<<" ";  
    cout<<"\n";  
      
    return 0;  
}

然后还需要一个FWT相关的性质,用A add B表示A,B序列相同下标的元素相加(A和B序列长度相同),ifwt(A)表示对序列A进行逆FWT变化得到的序列。 那么:

ifwt(A)_i+ifwt(B)_i=ifwt(A add B)_i。

有一个简单的证明方式,由于ifwt(A)_i是A序列中一些固定位置上的数的和与一些固定位置上的数的差,例如(不一定正确,只是个例子)ifwt(A)_1=A_1+A_3-A_2-A_4。ifwt(B)_i只是将序列A换作序列B,则ifwt(B)_1=B_1+B_3-B_2-B_4,那么等式左边是分成A和B两个序列,然后让相应位置上的数加加减减,再将结果合并,等式右边将两个序列合并,然后再让相应位置上的数加加减减,由于加法有交换律,结果相同。

题解

考虑动态规划,设dp[i][j]表示用第1到i堆石子,异或结果为j的方案数。 则要求1<=i<=n,1<=j<=V(V是可能的最大异或结果),dp[i][j]的和。

i=1: dp[1][j]是数量为j的石头堆的数目。

i>=2: dp[i][j]=求和(dp[i-1][j异或t]* dp[1][t]),可以发现这个状态转移方程满足FWT-xor的形式。

dp[1]=ifwt(fwt(dp[1]))

dp[2]=ifwt(fwt(dp[1]) 对应位置相乘 fwt(dp[1]))

dp[3]=ifwt(fwt(dp[2]) 对应位置相乘 fwt(dp[1]))

将dp[2]的表示式带入,dp[3]=ifwt(fwt^3(dp[1]) )

根据“前置知识”中的性质 dp[1]_i+dp[2]_i+...+dp[n]_i 等于ifwt(fwt(dp[1])的每一项等比数列n项求和得到的序列)_i。

此时,复杂度是Vlog(V)。

/*  
using C++17 standard  
upd:25.09.16  
*/  
#include<bits/stdc++.h>  
using namespace std;  
typedef long long ll;  
typedef unsigned long ul;  
typedef unsigned long long ull;  
  
#define fastio ios::sync_with_stdio(0);cin.tie(0);  
#define fin freopen("D:/in.txt","r",stdin);  
#define fout freopen("D:/out.txt","w",stdout);  
  
const ll maxv=(1LL<<16)+5,V=(1LL<<16)-1,mod=998244353;  
ll dp1[maxv],a[maxv],fwt[maxv],ifwt[maxv];  
ll n,k,inv2;  
  
ll binpow(ll a,ll n) {  
    ll res=1;  
    a%=mod;  
    while(n>0) {  
        if(n&1) res=res*a%mod;  
        a=a*a%mod;  
        n>>=1;  
    }  
    return res;  
}  
  
ll get_inv(ll x) {  
    ll res=binpow(x,mod-2);  
    return res;  
}  
  
void calc_fwt_xor(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        fwt[L]=a[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    calc_fwt_xor(a,fwt,L,mid);  
    calc_fwt_xor(a,fwt,mid+1,R);  
    static ll temp_fwt[maxv];  
    for(ll i=L;i<=R;i++) temp_fwt[i]=fwt[i];  
    for(ll i=L;i<=mid;i++){  
        fwt[i]=(temp_fwt[i]+temp_fwt[i+(mid-L+1)])%mod;  
    }  
    for(ll i=mid+1;i<=R;i++) {  
        fwt[i]=(-temp_fwt[i]+temp_fwt[i-(mid-L+1)]+mod)%mod;  
    }  
      
}  
  
void calc_seq_xor(ll *a,ll *fwt,ll L,ll R) {  
    if(L==R) {  
        a[L]=fwt[L]%mod;  
        return ;  
    }  
    ll mid=(L+R)>>1;  
    static ll temp_fwt[maxv];  
    for(ll i=L;i<=mid;i++) {  
        temp_fwt[i]=(fwt[i]+fwt[i+(mid-L+1)])%mod*inv2%mod;  
    }  
    for(ll i=mid+1;i<=R;i++) {  
        temp_fwt[i]=(fwt[i-(mid-L+1)]-fwt[i]+mod)%mod*inv2%mod;  
    }  
    for(ll i=L;i<=R;i++) fwt[i]=temp_fwt[i];  
    calc_seq_xor(a,fwt,L,mid);  
    calc_seq_xor(a,fwt,mid+1,R);      
}  
  
int main()  
{  
    inv2=get_inv(2);  
      
    cin>>n>>k;  
    for(ll i=1;i<=k;i++) cin>>a[i];  
      
    for(ll i=1;i<=k;i++) dp1[a[i]]++;  
      
    calc_fwt_xor(dp1,fwt,0,V);  
      
    for(ll i=0;i<=V;i++) {  
        if(fwt[i]==0continue;  
        if(fwt[i]==1) fwt[i]=n%mod;  
        else {  
            ll res=fwt[i]*(1-binpow(fwt[i],n))%mod*get_inv(1-fwt[i])%mod;  
            //ll res=dp1[i]*(binpow(dp1[i],n)-1)%mod*get_inv(dp1[i]-1)%mod;  
            fwt[i]=res;  
        }  
    }  
      
      
    calc_seq_xor(ifwt,fwt,0,V);  
      
    ll ans=0;  
    for(ll i=1;i<=V;i++) ans=(ans+ifwt[i])%mod;  
    cout<<ans<<"\n";  
      
    return 0;  
}