【CCPC】2022桂林站 G. Group Homework | 换根DP

406 阅读3分钟

本文已参与「新人创作礼」活动, 一起开启掘金创作之路。

【CCPC】2022桂林站 G. Group Homework | 换根DP

题目链接

Problem - G - Codeforces

题目

image.png

题目大意

给出一颗有 nn 个节点的无根树,第 ii 个节点的权值为 aia_i。在树上选两条简单路径,最大化仅被一条选中的简单路径覆盖的点的权值和。

思路

首先,这两条简单路径最多有一个交点。因为如果最优解选中的两条简单路径有两个或以上的交点,如下图所示,方案一显然不如方案二优: image.png

则对于有一个交点的情况,我们可以枚举交点,则答案是以交点为端点的最长的四条带权链的长度和;对于没有交点的情况我们只需要枚举切断哪条边,然后答案是形成的两棵树的带权直径之和。这两件事都可以用换根 DP 求解。下面进行详细说明。

我们可以先以节点 1 为根节点把整棵树拎起来,令 son[u]son[u] 表示与节点 uu 相邻的点的集合,f[u]f[u] 表示节点 uu 的父节点,则可以预处理出以每个节点 uu 为根节点的子树中的以下信息:

  • uu 为端点的最长链的长度 dp[u]=maxvson[u],vf[u]{dp[v]}+a[u]dp[u]=\max_{v\in son[u],v\neq f[u]}\{dp[v]\}+a[u]
  • uu 所有子节点中 dpdp 值前四大的节点下标 ms[u][0,1,2,3]ms[u][0,1,2,3],遍历子节点时顺次更新;
  • 子树 uu 的直径 w[u]=max(maxvson[u],vf[u]{w[v]},dp[ms[u][0]]+dp[ms[u][1]]+a[u])w[u]=\max(\max_{v\in son[u],v\neq f[u]}\{w[v]\},dp[ms[u][0]]+dp[ms[u][1]]+a[u])
  • uu 所有子节点中 ww 值最大和次大的节点下标 mss[u][0,1]mss[u][0,1],遍历子节点时顺次更新。

对于有一个交点的情况,以当前遍历到的节点 uu 为整棵树的根,则其树 f[u]f[u] 除去子树 uu 剩下的部分就成为了 uu 的一棵子树。以其 f[u]f[u] 为根的子树的最长链长度以参数形式传来,用端点为 uu 没有公共点的四条最长带权链的长度和更新答案。接着枚举整个 uu 的所有子节点 vv,如果 vv 是其 dpdp 值最大的子节点,则树 uu 除去子树 vv 剩下的部分以 uu 为端点的最长带权链的长度就是 max(dp[ms[u][1]],d)+a[u]max(dp[ms[u][1]],d)+a[u],否则是 max(dp[ms[u][0]],d)+a[u]max(dp[ms[u][0]],d)+a[u]。递归计算即可。

枚举切掉的边与上述情况类似。

输出统计的答案。

代码

#include <bits/stdc++.h>
#define nnn printf("No\n")
#define yyy printf("Yes\n")
using namespace std;
using LL=long long;
const int N=200001;
int a[N],n,x,y;
int dp[N],ms[N][4],w[N],mss[N][2];
vector<int> e[N];
void getmax4(int u,int fa)
{
	int x;
	for (auto v:e[u])
	{
		if (v==fa) continue;
		getmax4(v,u);
		dp[u]=max(dp[u],dp[v]);
		w[u]=max(w[u],w[v]);
		x=v;
		for (int i=0;i<4;++i)
			if (dp[x]>dp[ms[u][i]]) swap(x,ms[u][i]);
		x=v;
		for (int i=0;i<2;++i)
			if (w[x]>w[mss[u][i]]) swap(x,mss[u][i]);
	}
	dp[u]+=a[u];
	w[u]=max(w[u],dp[ms[u][0]]+dp[ms[u][1]]+a[u]);
}
int ans=0;
void dfscv(int u,int fa,int d)
{
	ans=max(ans,dp[ms[u][0]]+dp[ms[u][1]]+dp[ms[u][2]]+max(d,dp[ms[u][3]]));
	for (auto v:e[u])
	{
		if (v==fa) continue;
		if (ms[u][0]==v) dfscv(v,u,max(dp[ms[u][1]],d)+a[u]);
		else dfscv(v,u,max(dp[ms[u][0]],d)+a[u]);
	}
}
void dfsce(int u,int fa,int d1,int d2)
{
	int dd1,dd2;
	for (auto v:e[u])
	{
		if (v==fa) continue;
		if (v==ms[u][0]) dd1=max(dp[ms[u][1]],d1)+a[u];
		else dd1=max(dp[ms[u][0]],d1)+a[u];
		if (v==mss[u][0]) dd2=w[mss[u][1]];
		else dd2=w[mss[u][0]];
		if (v==ms[u][0]) dd2=max(dd2,dp[ms[u][1]]+max(dp[ms[u][2]],d1)+a[u]);
		else if (v==ms[u][1]) dd2=max(dd2,dp[ms[u][0]]+max(dp[ms[u][2]],d1)+a[u]);
		else dd2=max(dd2,dp[ms[u][0]]+max(dp[ms[u][1]],d1)+a[u]);
		ans=max(ans,w[v]+dd2);
		dfsce(v,u,dd1,dd2);
	}
}
int main()
{
	scanf("%d",&n);
	for (int i=1;i<=n;++i) scanf("%d",&a[i]);
	for (int i=1;i<n;++i)
	{
		scanf("%d%d",&x,&y);
		e[x].push_back(y);
		e[y].push_back(x);
	}
	getmax4(1,0);
	dfscv(1,0,0);
	dfsce(1,0,0,0);
	cout<<ans;
	return 0;
}