【Codeforces】Codeforces Round #307 (Div. 2) D.GukiZ and Binary Operations | 矩乘、计数

146 阅读2分钟

本文已参与「新人创作礼」活动, 一起开启掘金创作之路。

【Codeforces】Codeforces Round #307 (Div. 2) D.GukiZ and Binary Operations | 矩乘、计数

题目链接

Problem - 551D - Codeforces

题目

image.png

题目大意

对于给定的 n,k,l,modn,k,l,mod,求构造一个长度为 nn 的满足以下条件的数列 a1,a2,...,ana_1,a_2,...,a_n 的方案数对 modmod 取余的结果:

  1. 0ai<2l0\le a_i<2^l
  2. (a1&a2)(a2&a3)...(an1&an)=k(a_1 \& a_2)|(a_2 \& a_3)|...|(a_{n-1} \& a_n)=k

思路

每个二进制位之间相互独立,可以分别计算。

对于一个串 01 序列,我们带入 (a1&a2)(a2&a3)...(an1&an)(a_1 \& a_2)|(a_2 \& a_3)|...|(a_{n-1} \& a_n),当且仅当存在两个相邻的位置同时为 1 时,最终的计算结果为 1。即我们需要求长度为 nn 的 01 序列中有多少种不包括两个相邻的 1,则包括相邻 1 的答案可以与 2n2^n 作差取得。

长度为 nn 的 01 序列中有多少种不包括两个相邻的 1,容易想到DP:

  • 令 dp[i][0/1] 表示长度为 i,以 0 或 1 结尾的序列有多少种不包含相邻两个 1。
  • dp[i][0]=dp[i-1][0]+dp[i-1][1],如果我们在当前位置放 0,前一位既可以是 0,也可以是 1。
  • dp[i][1]=dp[i-1][0],如果我们在当前位置放 1,前一位只能是 0。

但是 nn 太大了,我们考虑用矩阵乘法快速幂进行优化。容易发现:

[1110][dp[i1][0]dp[i1][1]]=[dp[i1][0]+dp[i1][1]dp[i1][0]]=[dp[i][0]dp[i][1]]\begin{gathered} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} \begin{bmatrix} dp[i-1][0] \\ dp[i-1][1] \end{bmatrix} = \begin{bmatrix} dp[i-1][0]+dp[i-1][1] \\ dp[i-1][0] \end{bmatrix} = \begin{bmatrix} dp[i][0] \\ dp[i][1] \end{bmatrix} \end{gathered}

进而:

[1110]n1[11]=[dp[n][0]dp[n][1]]\begin{gathered} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^{n-1} \begin{bmatrix} 1 \\ 1 \end{bmatrix} = \begin{bmatrix} dp[n][0] \\ dp[n][1] \end{bmatrix} \end{gathered}

所以我们可以在 O(logn)O(logn) 的时间里求出长度为 nn 的 01 序列带入式子计算结果为 1 的方案数。进而求得计算结果为 0 的方案数。

只后我们只需遍历 ll 个二进制位,如果该位 kk 为 0,则乘计算结果为 0 的方案数,否则乘计算结果为 1 的方案数。此外,如果 k>=2lk>=2^l 则无法构造满足条件的序列,方案数为 0。

代码

#include <bits/stdc++.h>
using namespace std;
using LL=long long;
const int N=500001;
LL mod;
struct martix{
	int n,m;
	int a[3][3];
	martix operator * (const martix b) const
	{
		martix ans;
		ans.n=n;
		ans.m=b.m;
		memset(ans.a,0,sizeof(ans.a));
		for (int i=1;i<=n;++i)
			for (int k=1;k<=b.m;++k)
				for (int j=1;j<=m;++j)
					ans.a[i][k]=(ans.a[i][k]+1ll*a[i][j]*b.a[j][k])%mod;
		return ans;
	}
}o;
martix poww(martix a,LL b)
{
	martix ans;
	ans.n=ans.m=2;
	memset(ans.a,0,sizeof(ans.a));
	ans.a[1][1]=ans.a[2][2]=1;
	for (;b;b>>=1,a=a*a)
		if (b&1) ans=ans*a;
	return ans;
}
LL poww(LL a,LL b)
{
	LL ans=1;
	for (;b;b>>=1,a=a*a%mod)
		if (b&1) ans=ans*a%mod;
	return ans;
}
int main()
{
	LL n,k;
	int l;
	cin>>n>>k>>l>>mod;
	if (k>=(1ll<<l)&&l<=60) return printf("0"),0;
	martix f0,sft,fn;
	f0=sft=o;
	f0.n=sft.n=sft.m=2;
	f0.m=f0.a[1][1]=f0.a[2][1]=sft.a[1][1]=sft.a[1][2]=sft.a[2][1]=1;
	fn=poww(sft,n-1)*f0;
	LL cnt0=(fn.a[1][1]+fn.a[2][1])%mod;
	LL cnt1=(poww(2,n)+mod-cnt0)%mod;
	LL ans=1;
	for (int i=0;i<l;++i)
	{
		if ((k>>i)&1) ans=ans*cnt1%mod;
		else ans=ans*cnt0%mod;
	}
	printf("%lld",ans%mod);
	return 0;
}