【Codeforces】Polynomial Round 2022 E. Two Chess Pieces | 树、倍增

163 阅读3分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第23天,点击查看活动详情

【Codeforces】Polynomial Round 2022 E. Two Chess Pieces | 树、倍增

题目链接

Problem - 1774E - Codeforces

题目

image.png

题目大意

一棵有 nn 个节点的树,你可以对它进行操作。

最初,树的节点 11 上有两个棋子。在一次操作中,可以选择任意棋子,并将其移动到相邻节点。需要确保任意时刻两个棋子之间的距离不会超过给定的正整数 dd

棋子分别有一个必须要访问的节点列表,访问顺序随意。最终,两个棋子必须返回根节点。最少需要的操作步数是多少?

思路

假设棋子 00 必须要访问的节点为 {x1,x2,...,xn0}\{x_1,x_2,...,x_{n_0}\},棋子 11 必须要访问的节点为 {y1,y2,...,yn1}\{y_1,y_2,...,y_{n_1}\},如果不存在两个节点距离不能超过 dd 的限制,那么两个棋子的答案相互独立。

只考虑棋子 00 的答案,最优解形如 DFS 整棵树,走到当前路径上最深的必须访问节点就折返:我们只需要给棋子 00 必须访问的每个节点都打上标记,如果一条边通向的子树中存在带有标记的点,这条边就会被访问两次。所以记录满足上述条件的边的数量,乘 22 即为棋子 00 对答案的贡献,可以通过遍历一遍树求得。棋子 11 同理。

现在加上距离不能超过 dd 的限制,我们把棋子 11 必须要访问的点的第 dd 层祖先列为棋子 00 必须访问的点,将棋子 00 必须访问的节点的 dd 层祖先也加入到棋子 11 必须访问的点。显然按最优的每条边恰访问两次的策略可以满足两棋子距离不超过 dd,可以忽略该条件,按之前介绍的步骤分别求解两棋子对答案的贡献。

代码

#include <stdio.h>
#include <algorithm>
#include <string.h>
#include <iostream>
#include <math.h>
#include <map>
#include <queue>
#include <vector>
#include <stdlib.h>
#include <time.h>
using namespace std;
using LL=long long;
const int N=1e6+5;
//const LL mod=998244353;
const LL mod=1e9+7;
int n,f[N][31],k;
vector<int> e[N];
int v[2][N],w[2][N],m[2],dep[N];
void getfa(int u,int fa)
{
	for (auto vv:e[u])
	{
		if (vv==fa) continue;
		dep[vv]=dep[u]+1;
		f[vv][0]=u;
		getfa(vv,u);
	}
}
int gtf(int u,int cnt)
{
	for (int i=20;i>=0;--i)
		if ((cnt>>i)&1) u=f[u][i];
	return u;
}
int ans=0;
int solve(int u,int fa,int t)
{
	int flag=0;
	for (auto vv:e[u])
	{
		if (vv==fa) continue;
		if (solve(vv,u,t))
		{
			flag=1;
			ans+=2;
		}
	}
	if (v[t][u]==1) flag=1;
	return flag;
}
int main()
{
	scanf("%d%d",&n,&k);
	for (int i=1,x,y;i<n;++i)
	{
		scanf("%d%d",&x,&y);
		e[x].push_back(y);
		e[y].push_back(x);
	}
	for (int j=0;j<=1;++j)
	{
		scanf("%d",&m[j]);
		for (int i=1;i<=m[j];++i)
		{
			scanf("%d",&w[j][i]);
			v[j][w[j][i]]=1;
		}
	}
	getfa(1,0);
	for (int t=0;t<20;++t)
		for (int i=1;i<=n;++i) f[i][t+1]=f[f[i][t]][t];
	for (int j=0;j<=1;++j)
	{
		for (int i=1,p;i<=m[j];++i)
		{
			p=gtf(w[j][i],k);
			if (p!=0) v[!j][p]=1;
		}
	}
	solve(1,0,0);
	solve(1,0,1);
	printf("%d\n",ans);
	return 0;
}