树上分治详解 超级详细(附带例题 poj1741(给了题目))

57 阅读6分钟

例题大概意思就是有一颗有 n 个顶点的树,其中连接顶点 a_i 和 b_i 的边 i 的长度为 l ,然后统计最短距离不超过 k 的顶点的对数

(虽然篇幅比较长,但是看完会有收获的)

树上的分治,与其他的分治意思相同,都是把问题分而治之,比如数列分治,我们直接从中间一分为二,平面分治,我们左右或上下一分为二,那树上分治我们怎么分才合理呢

我们需要考虑子问题大小,数列分治和平面分治都是将问题规模缩小了一半的大小,这样才把递归深度控制在O(logn) ,那树我们也需将子树的大小控制在不大于 n / 2 ,不然会发生递归深度退化,这时我们需要引入重心的概念

👌

重心:

重心定义我们称若删除某顶点后得到的最大子树的顶点数最少,则该顶点为此树重心,不难得出删除该点后子树得顶点数一定不超过 n / 2
如何得到若想要找到重心,我们可以沿着最大子树方向移动(向最大子树方向移动,便相当于减小最大子树大小,直至找到当移动后的点,最大子树大小大于移动前的那点。其他非最大子树方向不可能存在重心,因为向非最大子树方向移动,最大子树的大小会越来越大)
(如何寻找重心,后面有详细代码解析)

👌

既然我们知道了如何找到重心,现在我们需要考虑:

怎么分而治之

我们将一棵树分为一个重心S,和若干个子树
(下图相当于分为一个重心和四个子树)
在这里插入图片描述

👌
现在顶点之间有

三种情况

(1)顶点对中两点同在一个子树中
(2)顶点对中两点在不同子树中
(3)顶点对中两点一个在子树中,一个为重心
(注意:数组和平面分治的一分为二里,都不会单独删除某点,而树的分治相当于分为一个重心和若干个子树)

👌

如何处理情况

对于(1)我们递归可得
对于(2),我们对于每个子树进行操作一次,遍历子树中所有的顶点,将各个点的到 重心S 的距离记录一下,然后若不同子树中的两个点到 重心S 的距离和大于K,则结果加 1
对于(3),只要某个距离记录大于 K ,则结果加 1 即可,为了方便操作我们也可以在距离记录中增加一个 0 ,相当于添加一个新的只有一个顶点的子树,且这个顶点到 重心S 的距离为 0

👌

代码思路

①我们先遍历当前这个树的所有顶点,记录以各个顶点为根时,该子树上的顶点数目(数组count_point保存)
②找到当前这个树的重心(沿最大子树遍历下去)
③对以重心连接几个点为根的子树递归下去((1)情况)
④分别对于每个子树进行遍历,将上面所有点记录下
⑤对于记录下的点,我们进行判断,若距离和不大于 K ,结果加 1((2)情况)

注意:由于我们距离记录中,可能同一子树的顶点对的距离和也不大于 k 被误判,所以我们在④中也要进行判断两点距离和,并减去所得到的结果,即相当于对⑤中的结果中误判部分抵消

注意:但我们找到重心时,我们要把它标记出来,相当于把它从树中删去,这样才相当于将几个的子树分开了,但我们需考虑我们之前删除的重心会一直影响到之后的操作,但是之后的删除的重心不能影响之前的操作,所以在之后一定要记得把重心 ”归还“ 回树
请添加图片描述
相当于我们处理当重心为 S2 时,重心 S1 需要存在,否则 S2 的右上方子树会变的比真实情况要大,会导致子树不正确
但是当我们回溯会处理重心 S1 的时候,重心 S2 不能存在,若此时 S2 还存在,则 S1 的左下方子树会变的比真实情况要小,会导致子树不正确

👌

如何寻找重心解析:

第①种方法:
沿着最大树方向走
参数 n 为当前点, 参数 last 为上一个点
参数 f 记录了该树的根(主要用处为了记录该树的顶点数)

为什么要知道该树的顶点数?
因为我们不知道以 last 为根的子树的顶点数,所以需要 ( |该树的顶点数| - |n连接的除last外所有顶点为根时顶点数和| - 1) 来求
(那个 1 是指被删除的重心)
如图:即我们从 f 出发,遍历到 n ,此时我们不知道以 last 为根的子树的顶点数
请添加图片描述

这样我们通过比较,便可以得到当以 n 为重心时,最大子树的顶点数的大小
又因为我们是沿着最大子树走,所以以 last 为重心时,以 n 为根的顶点数就是最大子树的顶点数的大小

所以我们可以通过比较两个顶点数大小,来判断是否往下继续走

int find_core(int n, int last, int f)
{
    int flag = -1, c = 0;
    for(int i = 0; i < data[n].size(); i++)
    {
        int w = data[n][i].to;
        if(w != last && !delete_point[w])
        {
            if(flag == -1)
            {
                flag = i;
            }
            else
            {
                if(count_point[w] > count_point[data[n][flag].to])
                {
                    flag = i;
                }
            }
            c += count_point[w];

        }
    }
    if(flag == -1 && last == -1)
    {
        return n;
    }
    else if(flag == -1)
    {
        return last;
    }
    c = max(count_point[data[n][flag].to], count_point[f] - c - 1);
    if(c >= count_point[n])
    {
        return last;
    }
    return find_core(data[n][flag].to, n, f);
}

第②种方法:
遍历所有点,不光沿着最大子树走,这样复杂度也只是O(n),不会超时,并且代码相对简单
(我们最终取递归结果的 .second 即可)

pair<int, int> find_core(int v, int p, int t)
{
    pair<int, int> res = make_pair(INF, -1);
    int s = 1, m = 0;
    for(int i = 0; i < data[v].size(); i++)
    {
        int w = data[v][i].to;
        if(w == p || delete_point[w]) continue;
        res = min(res, find_core(w, v, t));
        m = max(m, count_point[w]);
        s += count_point[w];
    }
    m = max(m, t - s);
    res = min(res, make_pair(m, v));
    return res;
}

👌

AC代码

(如果还有不会的地方,可以评论或者私信,我们可以一起探讨)

#include <iostream>
#include <stdio.h>
#include <vector>
#include <algorithm>
#define INF 1000000005
using namespace std;
struct edge
{
    int to, length;
    edge(int t, int l)
    {
        to = t; length = l;
    }
};
vector<struct edge> data[10005];
int N, K;
int count_point[10005];
bool delete_point[10005];
void add_edge(int from, int to, int length)
{
    data[from].push_back(edge(to, length));
    data[to].push_back(edge(from, length));
}
void tag_point(int n, int last)
{
    int ans = 1;
    for(int i = 0; i < data[n].size(); i++)
    {
        int w = data[n][i].to;
        if(last != w && !delete_point[w])
        {
            tag_point(w, n);
            ans += count_point[w];
        }
    }
    count_point[n] = ans;
}
int find_core(int n, int last, int f)
{
    int flag = -1, c = 0;
    for(int i = 0; i < data[n].size(); i++)
    {
        int w = data[n][i].to;
        if(w != last && !delete_point[w])
        {
            if(flag == -1)
            {
                flag = i;
            }
            else
            {
                if(count_point[w] > count_point[data[n][flag].to])
                {
                    flag = i;
                }
            }
            c += count_point[w];

        }
    }
    if(flag == -1 && last == -1)
    {
        return n;
    }
    else if(flag == -1)
    {
        return last;
    }
    c = max(count_point[data[n][flag].to], count_point[f] - c - 1);
    if(c >= count_point[n])
    {
        return last;
    }
    return find_core(data[n][flag].to, n, f);
}
void measure(int n, int p, int l, vector<int> &d)
{
    d.push_back(l);
    for(int i = 0; i < data[n].size(); i++)
    {
        int w = data[n][i].to;
        if(w != p && !delete_point[w])
        {
            measure(w, n, l + data[n][i].length, d);
        }
    }
}
int count_num(vector<int> &d)
{
    int res = 0, num = d.size();
    sort(d.begin(), d.end());
    int j = num - 1;
    for(int i = 0; i < num && i < j; i++)
    {
        for(; d[i] + d[j] > K && i < j; j--){}
        res += j - i;
    }
    return res;
}
int solve(int n)
{
    int res = 0;
    tag_point(n, -1);
    int core = find_core(n, -1, n);
    delete_point[core] = true;
    for(int i = 0; i < data[core].size(); i++)
    {
        int w = data[core][i].to;
        if(!delete_point[w])
        {
            res += solve(data[core][i].to);
        }
    }
    vector<int> point_sum;
    point_sum.push_back(0);
    for(int i = 0; i < data[core].size(); i++)
    {
        int w = data[core][i].to;
        if(!delete_point[w])
        {
            vector<int> point;
            measure(data[core][i].to, core, data[core][i].length, point);
            res -= count_num(point);
            point_sum.insert(point_sum.end(), point.begin(), point.end());
        }
    }
    delete_point[core] = false;
    res += count_num(point_sum);
    return res;
}
int main()
{
    while(1)
    {
        scanf("%d %d", &N, &K);
        if(N == 0 && K == 0)
        {
            break;
        }
        for(int i = 0; i < N; i++)
        {
            data[i].clear();
        }
        fill(delete_point, delete_point + N, false);
        for(int i = 0; i < N - 1; i++)
        {
            int a, b, l;
            scanf("%d %d %d", &a, &b, &l);
            add_edge(a - 1, b - 1, l);
        }
        printf("%d\n", solve(0));
    }
}