本文已参与「新人创作礼」活动, 一起开启掘金创作之路。
【CCPC】2022桂林站 G. Group Homework | 换根DP
题目链接
题目
题目大意
给出一颗有 个节点的无根树,第 个节点的权值为 。在树上选两条简单路径,最大化仅被一条选中的简单路径覆盖的点的权值和。
思路
首先,这两条简单路径最多有一个交点。因为如果最优解选中的两条简单路径有两个或以上的交点,如下图所示,方案一显然不如方案二优:
则对于有一个交点的情况,我们可以枚举交点,则答案是以交点为端点的最长的四条带权链的长度和;对于没有交点的情况我们只需要枚举切断哪条边,然后答案是形成的两棵树的带权直径之和。这两件事都可以用换根 DP 求解。下面进行详细说明。
我们可以先以节点 1 为根节点把整棵树拎起来,令 表示与节点 相邻的点的集合, 表示节点 的父节点,则可以预处理出以每个节点 为根节点的子树中的以下信息:
- 以 为端点的最长链的长度 ;
- 所有子节点中 值前四大的节点下标 ,遍历子节点时顺次更新;
- 子树 的直径 ;
- 所有子节点中 值最大和次大的节点下标 ,遍历子节点时顺次更新。
对于有一个交点的情况,以当前遍历到的节点 为整棵树的根,则其树 除去子树 剩下的部分就成为了 的一棵子树。以其 为根的子树的最长链长度以参数形式传来,用端点为 没有公共点的四条最长带权链的长度和更新答案。接着枚举整个 的所有子节点 ,如果 是其 值最大的子节点,则树 除去子树 剩下的部分以 为端点的最长带权链的长度就是 ,否则是 。递归计算即可。
枚举切掉的边与上述情况类似。
输出统计的答案。
代码
#include <bits/stdc++.h>
#define nnn printf("No\n")
#define yyy printf("Yes\n")
using namespace std;
using LL=long long;
const int N=200001;
int a[N],n,x,y;
int dp[N],ms[N][4],w[N],mss[N][2];
vector<int> e[N];
void getmax4(int u,int fa)
{
int x;
for (auto v:e[u])
{
if (v==fa) continue;
getmax4(v,u);
dp[u]=max(dp[u],dp[v]);
w[u]=max(w[u],w[v]);
x=v;
for (int i=0;i<4;++i)
if (dp[x]>dp[ms[u][i]]) swap(x,ms[u][i]);
x=v;
for (int i=0;i<2;++i)
if (w[x]>w[mss[u][i]]) swap(x,mss[u][i]);
}
dp[u]+=a[u];
w[u]=max(w[u],dp[ms[u][0]]+dp[ms[u][1]]+a[u]);
}
int ans=0;
void dfscv(int u,int fa,int d)
{
ans=max(ans,dp[ms[u][0]]+dp[ms[u][1]]+dp[ms[u][2]]+max(d,dp[ms[u][3]]));
for (auto v:e[u])
{
if (v==fa) continue;
if (ms[u][0]==v) dfscv(v,u,max(dp[ms[u][1]],d)+a[u]);
else dfscv(v,u,max(dp[ms[u][0]],d)+a[u]);
}
}
void dfsce(int u,int fa,int d1,int d2)
{
int dd1,dd2;
for (auto v:e[u])
{
if (v==fa) continue;
if (v==ms[u][0]) dd1=max(dp[ms[u][1]],d1)+a[u];
else dd1=max(dp[ms[u][0]],d1)+a[u];
if (v==mss[u][0]) dd2=w[mss[u][1]];
else dd2=w[mss[u][0]];
if (v==ms[u][0]) dd2=max(dd2,dp[ms[u][1]]+max(dp[ms[u][2]],d1)+a[u]);
else if (v==ms[u][1]) dd2=max(dd2,dp[ms[u][0]]+max(dp[ms[u][2]],d1)+a[u]);
else dd2=max(dd2,dp[ms[u][0]]+max(dp[ms[u][1]],d1)+a[u]);
ans=max(ans,w[v]+dd2);
dfsce(v,u,dd1,dd2);
}
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;++i) scanf("%d",&a[i]);
for (int i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
getmax4(1,0);
dfscv(1,0,0);
dfsce(1,0,0,0);
cout<<ans;
return 0;
}