本文已参与「新人创作礼」活动, 一起开启掘金创作之路。
【CCPC】2022广州站 I. Infection | 动态规划
题目链接
题目
题目大意
有一颗含有 个节点 条边的树,根为 1。树上的每个点都有三个权值 。最初有一个节点被染色了,如果节点 的相邻节点被染色了,节点 被染色的概率是 。每个节点 成为最初被染色的那个节点的概率是 。
对于 ,求整棵树中恰有 个节点被染色的概率。
思路
考虑动态规划。
我们称全场第一个被染色的节点为关键节点。
表示 被染色的情况下以 为根的子树中有 个点被染色了,且这 个点都不是关键节点的概率。
表示 被染色的情况下以 为根的子树中有 个点被染色了,且这 个点存在关键节点的概率。
表示以 为根节点的子树的大小。
当访问到节点 时,初始化 。
最初点集 中只包括一个点 ,假设我们当前已经用 的若干个子节点对 和 进行了转移,当节点 更新过 和 之后,我们把以 为根节点的子树加入到点集 中。
当我们遍历到了 的子节点 ,此时
- 就是在点集 中选择大小为 个连通的点被染色了,且这些点不存在关键节点的概率。
- 就是在点集 中选择大小为 个连通的点被染色了,且这些点中有关键节点的概率。
如果我们直接在本地进行转移,在使用 之前就改变它的值,所以我们需要另外开辅助数组记录转移后的结果,在更新完辅助数组之后再把辅助数组里的值放回到 里。
最后我们再令 。以保证在 时我们一定会感染节点 。
在我们更新完成以 为根的子树中的答案后,将 与点 的父亲节点不被传染染色的概率之积加入到最终的答案数组里。
完成整个 DFS 后输出答案。容易发现每两个点只会在他们的 LCA 出进行一次合并,所以时间复杂度为 。
代码
#include <iostream>
#include <algorithm>
#include <math.h>
#include <stdio.h>
#include <map>
#include <vector>
#include <queue>
#define nnn printf("No\n")
#define yyy printf("Yes\n")
using namespace std;
using LL=long long;
const int N=2001;
const LL mod=1000000007;
vector<int> e[N];
int n,m,k,x,y,siz[N];
LL a[N],b[N],c[N];
LL poww(LL a,LL b)
{
LL ans=1;
for (;b;b>>=1,a=a*a%mod) if (b&1) ans=ans*a%mod;
return ans;
}
LL f[N][N],g[N][N],ans[N];
LL rf[N],rg[N];
void dfs(int u,int fa)
{
f[u][1]=b[u];
g[u][1]=a[u];
siz[u]=1;
for (auto v:e[u])
{
if (v==fa) continue;
dfs(v,u);
for (int i=1;i<=siz[v];++i)
ans[i]=(ans[i]+(1+mod-b[u])%mod*g[v][i]%mod)%mod;
for (int i=0;i<=siz[u]+siz[v];++i) rf[i]=rg[i]=0;
for (int i=0;i<=siz[u];++i)
for (int j=0;j<=siz[v];++j)
{
rf[i+j]=(rf[i+j]+f[u][i]*f[v][j]%mod)%mod;
rg[i+j]=(rg[i+j]+f[u][i]*g[v][j]%mod+g[u][i]*f[v][j]%mod)%mod;
}
siz[u]+=siz[v];
for (int i=0;i<=siz[u];++i)
{
f[u][i]=rf[i];
g[u][i]=rg[i];
}
}
f[u][0]=(1+mod-b[u])%mod;
}
void solve()
{
scanf("%d",&n);
for (int i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
LL sum=0;
for (int i=1;i<=n;++i)
{
scanf("%lld%lld%lld",&a[i],&b[i],&c[i]);
b[i]=b[i]*poww(c[i],mod-2)%mod;
sum+=a[i];
sum%=mod;
}
sum=poww(sum,mod-2)%mod;
for (int i=1;i<=n;++i) a[i]=a[i]*sum%mod;
dfs(1,0);
for (int i=1;i<=n;++i)
printf("%lld\n",(ans[i]+g[1][i])%mod);
}
int main()
{
int T=1;
while (T--) solve();
return 0;
}