arc096_c_everything on it分析与解答

50 阅读3分钟

提示:容斥原理,斯特林数。

定义选择的拉面组合为S,用A表示某种拉面,|A|=2^N(每种配料选或者不选),|S|=2^|A|(每种拉面选或者不选)。

定义坏事件E(i)(1<=i<=N)为:所有拉面组合S,他们都满足第i种配料只出现了小于2次。

这些坏事件之间有重合,因此考虑容斥原理。

用第一类容斥原理(用于统计不具有性质P1,P2...Pn的对象的个数):

fd225740b1f58320437ad3f4d2cf8377.jpg

接下来看看怎么计算h(i)。

h(i)表示恰好有i种配料不满足要求的拉面组合方法数量,如果在某个拉面组合中,有y种拉面包含这i种配料,那么有0<=y<=i,这是因为这i种配料每种要么出现0次,要么出现1次。利用第二类斯特林数,S(i+1,y+1)就可以表示将这i种配料分到y种拉面中有多少分法,在公式中,让物品和盒子都多一个,就能创造一个“垃圾桶”,放入垃圾桶中的配料在这y种拉面中都没有出现,多出来的那个物品所在盒子被标记为垃圾桶。接下来,这y种拉面还能加其他rest_cnt=N-i种配料,每种配料加或者不加,有2^rest_cnt种搭配,所以在那i种不符合要求的配料在y种拉面中的分配方法定下来后,有(2^rest_cnt)^y种其余配料的分配方法。

剩下的拉面是没有出现i种配料的拉面,有cnt=2^(N-i)种这样的拉面,每种拉面可以选或者不选,所以有2^cnt种组合。

得到最终公式:

image.png

在实现的时候,预处理第二类斯特林数和组合数,使用欧拉/费马降幂计算2^large_number。

  
/*  
using C++17 standard  
upd:25.09.16  
*/  
#include<bits/stdc++.h>  
using namespace std;  
typedef long long ll;  
typedef unsigned long ul;  
typedef unsigned long long ull;  
  
#define fastio ios::sync_with_stdio(0);cin.tie(0);  
#define fin freopen("D:/in.txt","r",stdin);  
#define fout freopen("D:/out.txt","w",stdout);  
  
const ll maxn=3000+5,N=3000;  
ll com[maxn][maxn],stir[maxn][maxn],fact[maxn],factinv[maxn];  
ll n,m;  
  
ll binpow(ll a,ll n,ll mod) {  
    ll res=1;  
    a=a%mod;  
    while(n>0) {  
        if(n&1) res=res*a%mod;  
        a=a*a%mod;  
        n>>=1;  
    }  
    return res;  
}  
  
ll get_com(ll n,ll m,ll mod) {  
    if(m<0 || m>n) return 0;  
    if(m==0 || m==n) return 1;  
    ll res=fact[n]*factinv[n-m]%mod*factinv[m]%mod;  
    return res;  
}  
  
int main()  
{  
    //计算第二类斯特林数  
    cin>>n>>m;  
    stir[0][0]=1;  
    for(ll i=1;i<=N+1;i++) {  
        stir[i][0]=0;  
        for(ll j=1;j<=i;j++) {  
            stir[i][j]=(j*stir[i-1][j]%m+stir[i-1][j-1])%m;  
        }  
    }  
      
    //计算阶乘和逆元  
    fact[0]=1;  
    for(ll i=1;i<=N;i++) fact[i]=fact[i-1]*i%m;  
    for(ll i=1;i<=N;i++) factinv[i]=binpow(fact[i],m-2,m);  
      
    //计算组合数  
    for(ll i=0;i<=N;i++) {  
        for(ll j=0;j<=i;j++) {  
            com[i][j]=get_com(i,j,m);  
        }  
    }  
      
    ll ans=0;  
    for(ll x=0;x<=n;x++) {  
        ll p1=0,p2=0;  
        if(x&1) p1=(-1+m)%m;  
        else p1=1;  
        p1=p1*com[n][x]%m;  
        ll e=binpow(2,n-x,m-1);  
        p1=p1*binpow(2,e,m)%m;  
        for(ll y=0;y<=x;y++) {  
            ll term=stir[x+1][y+1]*binpow(2,(n-x)*y,m)%m;  
            p2=(p2+term)%m;  
        }  
        ans=(ans+p1*p2%m)%m;  
    }  
    cout<<ans<<"\n";  
      
    return 0;  
}