【Codeforces】Codeforces Global Round 23 D - Paths on the Tree

229 阅读3分钟

本文已参与「新人创作礼」活动, 一起开启掘金创作之路。

【Codeforces】Codeforces Global Round 23 D - Paths on the Tree

题目链接

Problem - D - Codeforces

A Maxmina

题目大意

一棵具有 nn 个节点有根树,根节点为 11,第 ii 个节点的权值为 sis_i

在该树中选择 kk 条路径,要求满足以下要求:

  • 必须以节点 11 为路径的端点。
  • cic_i 表示所有被选择的路径中覆盖顶点 ii 的路径条数,对于任一节点,它的每一对子节点 v1,v2v_1,v_2 都应该满足 cv1cv21|c_{v_1}-c_{v_2}|\le 1

在满足上述条件的情况下选择 kk 条路径,最大化 i=1nci×si\sum_{i=1}^nc_i\times s_i

思路

显然一定有 c1=kc_1=k。因为所有点的权值非负,所以我们一定每次都选择从根到叶子节点的路径。

记节点 ii 子节点的数量为 cnticnt_i。考虑节点 11 的所有孩子,为了取得最优解,它们被覆盖过的次数一定是 kcnt1\frac{k}{cnt_1} 或者 kcnt1+1\frac{k}{cnt_1}+1

如果非叶子节点 uu 被覆盖过的次数只可能是 kuk_u 或者 ku+1k_u+1,则:

  • uu 被覆盖过的次数是 kuk_u 时,
    为了满足条件 cv1cv21|c_{v_1}-c_{v_2}|\le 1
    节点 uu 的子节点 vv 被覆盖过的次数将会先平分 kk,即被覆盖 t=kucntut=\frac{k_u}{cnt_u} 次。
    为了取得最优解,将会再选择 uud=(kumodcntu)d=(k_u\mod cnt_u) 个子节点再多访问一次。
    vv 只可能被访问 tt 次或 t+1t+1 次;
  • uu 被覆盖过的次数是 ku+1k_u+1 时,
    要么 ku+1cntu=t,(ku+1)modcntu=d+1\frac{k_u+1}{cnt_u}=t,(k_u+1)\mod cnt_u=d+1
    要么 ku+1cntu=t+1,(ku+1)modcntu=0\frac{k_u+1}{cnt_u}=t+1,(k_u+1)\mod cnt_u=0
    两种可能性 vv 都只可能被访问 tt 次或 t+1t+1 次;

所以如果非叶子节点 uu 被覆盖过的次数只可能是 kuk_u 或者 ku+1k_u+1,则其子节点 vv 只可能被访问 tt 次或 t+1t+1 次。又因节点 11 的所有孩子,被覆盖过的次数一定是 kcnt1\frac{k}{cnt_1} 或者 kcnt1+1\frac{k}{cnt_1}+1,所以树中所有点被覆盖的次数都只可能有两种取值。

我们可以记 dp[u][0/1]dp[u][0/1] 表示以 uu 为根节点的子树被覆盖了 t/t+1t/t+1 次的最优解, vvuu 的子节点。则:

dp[u][0]=suki+dp[v][0]+d(dp[v][1]dp[v][0])dp[u][0]=su(ki+1)+dp[v][0]+d+1(dp[v][1]dp[v][0])dp[u][0]=s_u*k_i+\sum{dp[v][0]}+\sum_{前 d 大} (dp[v][1]-dp[v][0])\\ dp[u][0]=s_u*(k_i+1)+\sum{dp[v][0]}+\sum_{前 d+1 大} (dp[v][1]-dp[v][0])

怎么求 d(dp[v][1]dp[v][0])\sum_{前 d 大} (dp[v][1]-dp[v][0]) 呢?

我们只需要在 DFS 的过程中把当前节点的所有子节点的 dpdp 值求完之后,将 (dp[v][1]dp[v][0])(dp[v][1]-dp[v][0]) 压入数组进行排序即可。因为树中所有节点的子节点数量之和为 n1n-1 个,整体的时间复杂度为 O(nlogn)O(nlogn)

代码

#include <bits/stdc++.h>
using namespace std;
using LL=long long;
const int N=500001;
const LL mod=1000000007;
//const LL mod=998244353;
vector<int> e[N];
int f[N];
int n,m;
LL k;
LL a[N],dp[N][2],ans,qwq[N];
int tot;
void dfs(int u,long long k)
{
	dp[u][0]=a[u]*k;
	dp[u][1]=a[u]*(k+1);
	if (e[u].size()==0) return;
	int t=k/e[u].size();
	for (auto v:e[u])
	{
		dfs(v,t);
		dp[u][0]+=dp[v][0];
		dp[u][1]+=dp[v][0];
	}
	tot=0;
	for (auto v:e[u]) qwq[++tot]=max(0ll,dp[v][1]-dp[v][0]);
	sort(qwq+1,qwq+1+tot);
	t=k%e[u].size();
	for (int i=tot-t+1;i<=tot;++i) 
	{
		dp[u][1]+=qwq[i];
		dp[u][0]+=qwq[i];
	}
	dp[u][1]+=qwq[tot-t];
}
int solve()
{
	scanf("%d",&n);
	scanf("%lld",&k);
	for (int i=2;i<=n;++i)
	{
		scanf("%d",&f[i]);
		e[f[i]].push_back(i);
	}
	for (int i=1;i<=n;++i) scanf("%lld",&a[i]);
	dfs(1,k);
	printf("%lld\n",dp[1][0]);
	for (int i=1;i<=n;++i)
	{
		e[i].clear();
		dp[i][0]=dp[i][1]=0;
		f[i]=0;
	}
	return 0;
}
int main()
{
	int T=1;
	for (scanf("%d",&T);T--;) solve();
	return 0;
}