例题大概意思就是有一颗有 n 个顶点的树,其中连接顶点 a_i 和 b_i 的边 i 的长度为 l ,然后统计最短距离不超过 k 的顶点的对数
(虽然篇幅比较长,但是看完会有收获的)
树上的分治,与其他的分治意思相同,都是把问题分而治之,比如数列分治,我们直接从中间一分为二,平面分治,我们左右或上下一分为二,那树上分治我们怎么分才合理呢
我们需要考虑子问题大小,数列分治和平面分治都是将问题规模缩小了一半的大小,这样才把递归深度控制在O(logn) ,那树我们也需将子树的大小控制在不大于 n / 2 ,不然会发生递归深度退化,这时我们需要引入重心的概念
👌
重心:
重心定义:我们称若删除某顶点后得到的最大子树的顶点数最少,则该顶点为此树重心,不难得出删除该点后子树得顶点数一定不超过 n / 2
如何得到:若想要找到重心,我们可以沿着最大子树方向移动(向最大子树方向移动,便相当于减小最大子树大小,直至找到当移动后的点,最大子树大小大于移动前的那点。其他非最大子树方向不可能存在重心,因为向非最大子树方向移动,最大子树的大小会越来越大)
(如何寻找重心,后面有详细代码解析)
👌
既然我们知道了如何找到重心,现在我们需要考虑:
怎么分而治之
我们将一棵树分为一个重心S,和若干个子树
(下图相当于分为一个重心和四个子树)
👌
现在顶点之间有
三种情况
(1)顶点对中两点同在一个子树中
(2)顶点对中两点在不同子树中
(3)顶点对中两点一个在子树中,一个为重心
(注意:数组和平面分治的一分为二里,都不会单独删除某点,而树的分治相当于分为一个重心和若干个子树)
👌
如何处理情况
对于(1)我们递归可得
对于(2),我们对于每个子树进行操作一次,遍历子树中所有的顶点,将各个点的到 重心S 的距离记录一下,然后若不同子树中的两个点到 重心S 的距离和大于K,则结果加 1
对于(3),只要某个距离记录大于 K ,则结果加 1 即可,为了方便操作我们也可以在距离记录中增加一个 0 ,相当于添加一个新的只有一个顶点的子树,且这个顶点到 重心S 的距离为 0
👌
代码思路
①我们先遍历当前这个树的所有顶点,记录以各个顶点为根时,该子树上的顶点数目(数组count_point保存)
②找到当前这个树的重心(沿最大子树遍历下去)
③对以重心连接几个点为根的子树递归下去((1)情况)
④分别对于每个子树进行遍历,将上面所有点记录下
⑤对于记录下的点,我们进行判断,若距离和不大于 K ,结果加 1((2)情况)
注意:由于我们距离记录中,可能同一子树的顶点对的距离和也不大于 k 被误判,所以我们在④中也要进行判断两点距离和,并减去所得到的结果,即相当于对⑤中的结果中误判部分抵消
注意:但我们找到重心时,我们要把它标记出来,相当于把它从树中删去,这样才相当于将几个的子树分开了,但我们需考虑我们之前删除的重心会一直影响到之后的操作,但是之后的删除的重心不能影响之前的操作,所以在之后一定要记得把重心 ”归还“ 回树
相当于我们处理当重心为 S2 时,重心 S1 需要存在,否则 S2 的右上方子树会变的比真实情况要大,会导致子树不正确
但是当我们回溯会处理重心 S1 的时候,重心 S2 不能存在,若此时 S2 还存在,则 S1 的左下方子树会变的比真实情况要小,会导致子树不正确
👌
如何寻找重心解析:
第①种方法:
沿着最大树方向走
参数 n 为当前点, 参数 last 为上一个点
参数 f 记录了该树的根(主要用处为了记录该树的顶点数)
为什么要知道该树的顶点数?
因为我们不知道以 last 为根的子树的顶点数,所以需要 ( |该树的顶点数| - |n连接的除last外所有顶点为根时顶点数和| - 1) 来求
(那个 1 是指被删除的重心)
如图:即我们从 f 出发,遍历到 n ,此时我们不知道以 last 为根的子树的顶点数
这样我们通过比较,便可以得到当以 n 为重心时,最大子树的顶点数的大小
又因为我们是沿着最大子树走,所以以 last 为重心时,以 n 为根的顶点数就是最大子树的顶点数的大小
所以我们可以通过比较两个顶点数大小,来判断是否往下继续走
int find_core(int n, int last, int f)
{
int flag = -1, c = 0;
for(int i = 0; i < data[n].size(); i++)
{
int w = data[n][i].to;
if(w != last && !delete_point[w])
{
if(flag == -1)
{
flag = i;
}
else
{
if(count_point[w] > count_point[data[n][flag].to])
{
flag = i;
}
}
c += count_point[w];
}
}
if(flag == -1 && last == -1)
{
return n;
}
else if(flag == -1)
{
return last;
}
c = max(count_point[data[n][flag].to], count_point[f] - c - 1);
if(c >= count_point[n])
{
return last;
}
return find_core(data[n][flag].to, n, f);
}
第②种方法:
遍历所有点,不光沿着最大子树走,这样复杂度也只是O(n),不会超时,并且代码相对简单
(我们最终取递归结果的 .second 即可)
pair<int, int> find_core(int v, int p, int t)
{
pair<int, int> res = make_pair(INF, -1);
int s = 1, m = 0;
for(int i = 0; i < data[v].size(); i++)
{
int w = data[v][i].to;
if(w == p || delete_point[w]) continue;
res = min(res, find_core(w, v, t));
m = max(m, count_point[w]);
s += count_point[w];
}
m = max(m, t - s);
res = min(res, make_pair(m, v));
return res;
}
👌
AC代码
(如果还有不会的地方,可以评论或者私信,我们可以一起探讨)
#include <iostream>
#include <stdio.h>
#include <vector>
#include <algorithm>
#define INF 1000000005
using namespace std;
struct edge
{
int to, length;
edge(int t, int l)
{
to = t; length = l;
}
};
vector<struct edge> data[10005];
int N, K;
int count_point[10005];
bool delete_point[10005];
void add_edge(int from, int to, int length)
{
data[from].push_back(edge(to, length));
data[to].push_back(edge(from, length));
}
void tag_point(int n, int last)
{
int ans = 1;
for(int i = 0; i < data[n].size(); i++)
{
int w = data[n][i].to;
if(last != w && !delete_point[w])
{
tag_point(w, n);
ans += count_point[w];
}
}
count_point[n] = ans;
}
int find_core(int n, int last, int f)
{
int flag = -1, c = 0;
for(int i = 0; i < data[n].size(); i++)
{
int w = data[n][i].to;
if(w != last && !delete_point[w])
{
if(flag == -1)
{
flag = i;
}
else
{
if(count_point[w] > count_point[data[n][flag].to])
{
flag = i;
}
}
c += count_point[w];
}
}
if(flag == -1 && last == -1)
{
return n;
}
else if(flag == -1)
{
return last;
}
c = max(count_point[data[n][flag].to], count_point[f] - c - 1);
if(c >= count_point[n])
{
return last;
}
return find_core(data[n][flag].to, n, f);
}
void measure(int n, int p, int l, vector<int> &d)
{
d.push_back(l);
for(int i = 0; i < data[n].size(); i++)
{
int w = data[n][i].to;
if(w != p && !delete_point[w])
{
measure(w, n, l + data[n][i].length, d);
}
}
}
int count_num(vector<int> &d)
{
int res = 0, num = d.size();
sort(d.begin(), d.end());
int j = num - 1;
for(int i = 0; i < num && i < j; i++)
{
for(; d[i] + d[j] > K && i < j; j--){}
res += j - i;
}
return res;
}
int solve(int n)
{
int res = 0;
tag_point(n, -1);
int core = find_core(n, -1, n);
delete_point[core] = true;
for(int i = 0; i < data[core].size(); i++)
{
int w = data[core][i].to;
if(!delete_point[w])
{
res += solve(data[core][i].to);
}
}
vector<int> point_sum;
point_sum.push_back(0);
for(int i = 0; i < data[core].size(); i++)
{
int w = data[core][i].to;
if(!delete_point[w])
{
vector<int> point;
measure(data[core][i].to, core, data[core][i].length, point);
res -= count_num(point);
point_sum.insert(point_sum.end(), point.begin(), point.end());
}
}
delete_point[core] = false;
res += count_num(point_sum);
return res;
}
int main()
{
while(1)
{
scanf("%d %d", &N, &K);
if(N == 0 && K == 0)
{
break;
}
for(int i = 0; i < N; i++)
{
data[i].clear();
}
fill(delete_point, delete_point + N, false);
for(int i = 0; i < N - 1; i++)
{
int a, b, l;
scanf("%d %d %d", &a, &b, &l);
add_edge(a - 1, b - 1, l);
}
printf("%d\n", solve(0));
}
}