【CCPC】2022广州站 I. Infection | 动态规划

232 阅读3分钟

本文已参与「新人创作礼」活动, 一起开启掘金创作之路。

【CCPC】2022广州站 I. Infection | 动态规划

题目链接

Problem - I - Codeforces

题目

image.png

题目大意

有一颗含有 nn 个节点 n1n-1 条边的树,根为 1。树上的每个点都有三个权值 ai,bi,cia_i,b_i,c_i。最初有一个节点被染色了,如果节点 ii 的相邻节点被染色了,节点 ii 被染色的概率是 aibi\frac{a_i}{b_i}。每个节点 ii 成为最初被染色的那个节点的概率是 aij=1naj\frac{a_i}{\sum_{j=1}^{n}a_j}

对于 k=1,2,3,...,nk=1,2,3,...,n,求整棵树中恰有 kk 个节点被染色的概率。

思路

考虑动态规划。

我们称全场第一个被染色的节点为关键节点。

f[u][i]f[u][i] 表示 uu 被染色的情况下以 uu 为根的子树中有 ii 个点被染色了,且这 ii 个点都不是关键节点的概率。
g[u][i]g[u][i] 表示 uu 被染色的情况下以 uu 为根的子树中有 ii 个点被染色了,且这 ii 个点存在关键节点的概率。
siz[u]siz[u] 表示以 uu 为根节点的子树的大小。

当访问到节点 uu 时,初始化 f[u][1]=aubu,g[u][1]=auj=1najf[u][1]=\frac{a_u}{b_u},g[u][1]=\frac{a_u}{\sum_{j=1}^{n}a_j}

最初点集 SS 中只包括一个点 uu,假设我们当前已经用 uu 的若干个子节点对 f[u]f[u]d[u]d[u] 进行了转移,当节点 vv 更新过 f[u]f[u]d[u]d[u] 之后,我们把以 vv 为根节点的子树加入到点集 SS 中。

当我们遍历到了 uu 的子节点 vv',此时

  • i=0siz[u]j=0siz[v]f[u][i]×f[v][j]\sum_{i=0}^{siz[u]}\sum_{j=0}^{siz[v]} f[u][i]\times f[v][j] 就是在点集 SS 中选择大小为 i+ji+j 个连通的点被染色了,且这些点不存在关键节点的概率。
  • i=0siz[u]j=0siz[v]g[u][i]×f[v][j]+f[u][i]×g[v][j]\sum_{i=0}^{siz[u]}\sum_{j=0}^{siz[v]} g[u][i]\times f[v][j]+f[u][i]\times g[v][j] 就是在点集 SS 中选择大小为 i+ji+j 个连通的点被染色了,且这些点中有关键节点的概率。

如果我们直接在本地进行转移,在使用 f[u][i]f[u][i] 之前就改变它的值,所以我们需要另外开辅助数组记录转移后的结果,在更新完辅助数组之后再把辅助数组里的值放回到 f[u],g[u]f[u],g[u] 里。

最后我们再令 f[u][0]=1aubuf[u][0]=1-\frac{a_u}{b_u}。以保证在 i0i\neq 0 时我们一定会感染节点 uu

在我们更新完成以 uu 为根的子树中的答案后,将 g[u]g[u] 与点 uu 的父亲节点不被传染染色的概率之积加入到最终的答案数组里。

完成整个 DFS 后输出答案。容易发现每两个点只会在他们的 LCA 出进行一次合并,所以时间复杂度为 O(n2)O(n^2)

代码

#include <iostream>
#include <algorithm>
#include <math.h>
#include <stdio.h>
#include <map>
#include <vector>
#include <queue>
#define nnn printf("No\n")
#define yyy printf("Yes\n")
using namespace std;
using LL=long long;
const int N=2001;
const LL mod=1000000007;
vector<int> e[N];
int n,m,k,x,y,siz[N];
LL a[N],b[N],c[N];
LL poww(LL a,LL b)
{
	LL ans=1;
	for (;b;b>>=1,a=a*a%mod) if (b&1) ans=ans*a%mod;
	return ans;
}
LL f[N][N],g[N][N],ans[N];
LL rf[N],rg[N];
void dfs(int u,int fa)
{
	f[u][1]=b[u];
	g[u][1]=a[u];
	siz[u]=1;
	for (auto v:e[u])
	{
		if (v==fa) continue;
		dfs(v,u);
		for (int i=1;i<=siz[v];++i)
			ans[i]=(ans[i]+(1+mod-b[u])%mod*g[v][i]%mod)%mod;
		for (int i=0;i<=siz[u]+siz[v];++i) rf[i]=rg[i]=0;
		for (int i=0;i<=siz[u];++i)
			for (int j=0;j<=siz[v];++j)
			{
				rf[i+j]=(rf[i+j]+f[u][i]*f[v][j]%mod)%mod;
				rg[i+j]=(rg[i+j]+f[u][i]*g[v][j]%mod+g[u][i]*f[v][j]%mod)%mod;
			}
		siz[u]+=siz[v];
		for (int i=0;i<=siz[u];++i)
		{
			f[u][i]=rf[i];
			g[u][i]=rg[i];
		}
	}
	f[u][0]=(1+mod-b[u])%mod;
}

void solve()
{
	scanf("%d",&n);
	for (int i=1;i<n;++i)
	{
		scanf("%d%d",&x,&y);
		e[x].push_back(y);
		e[y].push_back(x);
	}
	LL sum=0;
	for (int i=1;i<=n;++i)
	{
		scanf("%lld%lld%lld",&a[i],&b[i],&c[i]);
		b[i]=b[i]*poww(c[i],mod-2)%mod;
		sum+=a[i];
		sum%=mod;
	}
	sum=poww(sum,mod-2)%mod;
	for (int i=1;i<=n;++i) a[i]=a[i]*sum%mod;
	dfs(1,0);
	for (int i=1;i<=n;++i)
		printf("%lld\n",(ans[i]+g[1][i])%mod);
}
int main()
{
	int T=1;
	while (T--) solve();
	return 0;
}