前置知识
前置知识:FWT(快速沃尔什变换),FWT相关性质。
先来看看FWT用于解决什么问题:
给两个长度相同且长度为2的n次方的序列(n>=0),序列下标从0开始到2^n-1,两序列为A和B,定义序列C,C[k]=求和(A[i]* B[j]) ,k=i位运算j。位运算有:与,或,异或。
令序列长度m=2^n,则遍历A时遍历B,需要m^2的复杂度,使用FWT可以用mlog(m)求出这个C序列。
以下是推导过程:
分割线
分割线
分割线
可以用洛谷这道模板题验证FWT程序。www.luogu.com.cn/problem/P47…
/*
C++17 Standard
upd:25.09.15
*/
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long ul;
typedef unsigned long long ull;
#define fastio ios::sync_with_stdio(0);cin.tie(0);
#define fin freopen("D:/in.txt","r",stdin);
#define fout freopen("D:/out.txt","w",stdout);
const ll maxn=2e5+5,mod=998244353,inv2=499122177;
ll a[maxn],b[maxn],c[maxn],fwta[maxn],fwtb[maxn],fwtc[maxn];
ll n,m;
void calc_fwt_or(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
fwt[L]=a[L]%mod;
return ;
}
ll mid=(L+R)>>1;
calc_fwt_or(a,fwt,L,mid);
calc_fwt_or(a,fwt,mid+1,R);
for(ll i=mid+1;i<=R;i++){
fwt[i]=(fwt[i]+fwt[i-(mid-L+1)])%mod;
}
}
void calc_seq_or(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
a[L]=fwt[L]%mod;
return ;
}
ll mid=(L+R)>>1;
for(ll i=mid+1;i<=R;i++){
fwt[i]=(fwt[i]-fwt[i-(mid-L+1)]+mod)%mod;
}
calc_seq_or(a,fwt,L,mid);
calc_seq_or(a,fwt,mid+1,R);
}
void calc_fwt_and(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
fwt[L]=a[L]%mod;
return ;
}
ll mid=(L+R)>>1;
calc_fwt_and(a,fwt,L,mid);
calc_fwt_and(a,fwt,mid+1,R);
for(ll i=L;i<=mid;i++){
fwt[i]=(fwt[i]+fwt[i+(mid-L+1)])%mod;
}
}
void calc_seq_and(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
a[L]=fwt[L]%mod;
return ;
}
ll mid=(L+R)>>1;
for(ll i=L;i<=mid;i++){
fwt[i]=(fwt[i]-fwt[i+(mid-L+1)]+mod)%mod;
}
calc_seq_and(a,fwt,L,mid);
calc_seq_and(a,fwt,mid+1,R);
}
void calc_fwt_xor(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
fwt[L]=a[L]%mod;
return ;
}
ll mid=(L+R)>>1;
calc_fwt_xor(a,fwt,L,mid);
calc_fwt_xor(a,fwt,mid+1,R);
static ll temp_fwt[maxn];
for(ll i=L;i<=R;i++) temp_fwt[i]=fwt[i];
for(ll i=L;i<=mid;i++){
fwt[i]=(temp_fwt[i]+temp_fwt[i+(mid-L+1)])%mod;
}
for(ll i=mid+1;i<=R;i++) {
fwt[i]=(-temp_fwt[i]+temp_fwt[i-(mid-L+1)]+mod)%mod;
}
}
void calc_seq_xor(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
a[L]=fwt[L]%mod;
return ;
}
ll mid=(L+R)>>1;
static ll temp_fwt[maxn];
for(ll i=L;i<=mid;i++) {
temp_fwt[i]=(fwt[i]+fwt[i+(mid-L+1)])%mod*inv2%mod;
}
for(ll i=mid+1;i<=R;i++) {
temp_fwt[i]=(fwt[i-(mid-L+1)]-fwt[i]+mod)%mod*inv2%mod;
}
for(ll i=L;i<=R;i++) fwt[i]=temp_fwt[i];
calc_seq_xor(a,fwt,L,mid);
calc_seq_xor(a,fwt,mid+1,R);
}
void fwt_or() {
//计算fwta,fwtb
calc_fwt_or(a,fwta,0,m-1);
calc_fwt_or(b,fwtb,0,m-1);
for(ll i=0;i<m;i++) fwtc[i]=fwta[i]*fwtb[i]%mod;
calc_seq_or(c,fwtc,0,m-1);
}
void fwt_and() {
//计算fwta,fwtb
calc_fwt_and(a,fwta,0,m-1);
calc_fwt_and(b,fwtb,0,m-1);
for(ll i=0;i<m;i++) fwtc[i]=fwta[i]*fwtb[i]%mod;
calc_seq_and(c,fwtc,0,m-1);
}
void fwt_xor() {
//计算fwta,fwtb
calc_fwt_xor(a,fwta,0,m-1);
calc_fwt_xor(b,fwtb,0,m-1);
for(ll i=0;i<m;i++) fwtc[i]=fwta[i]*fwtb[i]%mod;
calc_seq_xor(c,fwtc,0,m-1);
}
int main()
{
cin>>n;
m=(1LL<<n);
for(ll i=0;i<m;i++) cin>>a[i];
for(ll i=0;i<m;i++) cin>>b[i];
fwt_or();
//TODO
/*
printf("ftwa:\n");
for(ll i=1;i<=m;i++) cout<<fwta[i]<<" ";
cout<<"\n";
printf("ftwb:\n");
for(ll i=1;i<=m;i++) cout<<fwtb[i]<<" ";
cout<<"\n";*/
for(ll i=0;i<m;i++) cout<<c[i]<<" ";
cout<<"\n";
fwt_and();
for(ll i=0;i<m;i++) cout<<c[i]<<" ";
cout<<"\n";
fwt_xor();
for(ll i=0;i<m;i++) cout<<c[i]<<" ";
cout<<"\n";
return 0;
}
然后还需要一个FWT相关的性质,用A add B表示A,B序列相同下标的元素相加(A和B序列长度相同),ifwt(A)表示对序列A进行逆FWT变化得到的序列。 那么:
ifwt(A)_i+ifwt(B)_i=ifwt(A add B)_i。
有一个简单的证明方式,由于ifwt(A)_i是A序列中一些固定位置上的数的和与一些固定位置上的数的差,例如(不一定正确,只是个例子)ifwt(A)_1=A_1+A_3-A_2-A_4。ifwt(B)_i只是将序列A换作序列B,则ifwt(B)_1=B_1+B_3-B_2-B_4,那么等式左边是分成A和B两个序列,然后让相应位置上的数加加减减,再将结果合并,等式右边将两个序列合并,然后再让相应位置上的数加加减减,由于加法有交换律,结果相同。
题解
考虑动态规划,设dp[i][j]表示用第1到i堆石子,异或结果为j的方案数。 则要求1<=i<=n,1<=j<=V(V是可能的最大异或结果),dp[i][j]的和。
i=1: dp[1][j]是数量为j的石头堆的数目。
i>=2: dp[i][j]=求和(dp[i-1][j异或t]* dp[1][t]),可以发现这个状态转移方程满足FWT-xor的形式。
dp[1]=ifwt(fwt(dp[1]))
dp[2]=ifwt(fwt(dp[1]) 对应位置相乘 fwt(dp[1]))
dp[3]=ifwt(fwt(dp[2]) 对应位置相乘 fwt(dp[1]))
将dp[2]的表示式带入,dp[3]=ifwt(fwt^3(dp[1]) )
根据“前置知识”中的性质 dp[1]_i+dp[2]_i+...+dp[n]_i 等于ifwt(fwt(dp[1])的每一项等比数列n项求和得到的序列)_i。
此时,复杂度是Vlog(V)。
/*
using C++17 standard
upd:25.09.16
*/
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long ul;
typedef unsigned long long ull;
#define fastio ios::sync_with_stdio(0);cin.tie(0);
#define fin freopen("D:/in.txt","r",stdin);
#define fout freopen("D:/out.txt","w",stdout);
const ll maxv=(1LL<<16)+5,V=(1LL<<16)-1,mod=998244353;
ll dp1[maxv],a[maxv],fwt[maxv],ifwt[maxv];
ll n,k,inv2;
ll binpow(ll a,ll n) {
ll res=1;
a%=mod;
while(n>0) {
if(n&1) res=res*a%mod;
a=a*a%mod;
n>>=1;
}
return res;
}
ll get_inv(ll x) {
ll res=binpow(x,mod-2);
return res;
}
void calc_fwt_xor(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
fwt[L]=a[L]%mod;
return ;
}
ll mid=(L+R)>>1;
calc_fwt_xor(a,fwt,L,mid);
calc_fwt_xor(a,fwt,mid+1,R);
static ll temp_fwt[maxv];
for(ll i=L;i<=R;i++) temp_fwt[i]=fwt[i];
for(ll i=L;i<=mid;i++){
fwt[i]=(temp_fwt[i]+temp_fwt[i+(mid-L+1)])%mod;
}
for(ll i=mid+1;i<=R;i++) {
fwt[i]=(-temp_fwt[i]+temp_fwt[i-(mid-L+1)]+mod)%mod;
}
}
void calc_seq_xor(ll *a,ll *fwt,ll L,ll R) {
if(L==R) {
a[L]=fwt[L]%mod;
return ;
}
ll mid=(L+R)>>1;
static ll temp_fwt[maxv];
for(ll i=L;i<=mid;i++) {
temp_fwt[i]=(fwt[i]+fwt[i+(mid-L+1)])%mod*inv2%mod;
}
for(ll i=mid+1;i<=R;i++) {
temp_fwt[i]=(fwt[i-(mid-L+1)]-fwt[i]+mod)%mod*inv2%mod;
}
for(ll i=L;i<=R;i++) fwt[i]=temp_fwt[i];
calc_seq_xor(a,fwt,L,mid);
calc_seq_xor(a,fwt,mid+1,R);
}
int main()
{
inv2=get_inv(2);
cin>>n>>k;
for(ll i=1;i<=k;i++) cin>>a[i];
for(ll i=1;i<=k;i++) dp1[a[i]]++;
calc_fwt_xor(dp1,fwt,0,V);
for(ll i=0;i<=V;i++) {
if(fwt[i]==0) continue;
if(fwt[i]==1) fwt[i]=n%mod;
else {
ll res=fwt[i]*(1-binpow(fwt[i],n))%mod*get_inv(1-fwt[i])%mod;
//ll res=dp1[i]*(binpow(dp1[i],n)-1)%mod*get_inv(dp1[i]-1)%mod;
fwt[i]=res;
}
}
calc_seq_xor(ifwt,fwt,0,V);
ll ans=0;
for(ll i=1;i<=V;i++) ans=(ans+ifwt[i])%mod;
cout<<ans<<"\n";
return 0;
}