楼教主男人八题(第二题)

152 阅读6分钟

今天来搞1741 Tree

男人进度:2/8

题目链接

poj.org/problem?id=…

题目描述

Give a tree with nn vertices,each edge has a length(positive integer less than 1001).

The defination of dist(u,v)dist(u,v) is The min distance between node uu and vv.

Give an integer kk,for every pair (uu,vv) of vertices is called valid if and only if dist(u,v)dist(u,v) not exceed kk.

Write a program that will count how many pairs which are valid for a given tree.

输入

The input contains several test cases.

The first line of each test case contains two integers nn, kk. (n<=10000) The following n1n-1 lines each contains three integers uu,vv,ll, which means there is an edge between node uu and vv of length ll.

The last test case is followed by two zeros.

输出

For each test case output the answer on a single line.

样例输入

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

样例输出

8

样例解释

样例给出的树如上图示,共有8条符合要求的边:

  • 1->2,距离为 3
  • 1->3,距离为 1
  • 1->4,距离为 2
  • 1->3->5,距离为 1+1=2
  • 2->1->3,距离为 3+1=4
  • 3->5,距离为 1
  • 3->1->4,距离为 1+2=3
  • 4->1->3->5,距离为 2+1+1=4

题解

知识点树的重心

因为题目给出的是无向边,那么把哪个节点当作根节点都无所谓了。那么不妨设有一棵以节点1为根的树,长成下面这个样子:

仔细观察,我们可以依据是否包含根节点,把路径分为两类:

  • 包含根节点(比如 1->2,4->1->5)
  • 不包含根节点(比如 2->3, 5->6)

不难发现,不包含的根节点的路径,其所包含的节点必然都来自于同一棵子树。这就很符合递归的要求了,为了描述清楚,我们不妨先设:

  • 以节点ii 为根的子树的合法路径总数为 t(i)t(i)
  • 以节点ii 为根的子树的包含节点ii的合法路径总数为 f(i)f(i)

t(i)t(i) 可以递归的定义为:

t(i)={0,i是叶子si的儿子t(s)+f(i),i不是叶子t(i)= \begin{cases} 0, i是叶子 \\ \sum_{s是i的儿子} t(s) + f(i), i不是叶子 \end{cases}

问题简化了,答案可表示为t(1)=i=1nf(i)t(1) = \sum_{i=1}^{n} f(i)。那么如何求解 f(i)f(i) 咧?

对于每个节点 ii,做对应子树遍历,计算出该子树中每个节点和节点 ii 的距离。

比如有上面这样一棵树,为了方便,我们不妨设边的权值都为 1。那么每棵子树上的节点到对应根节点的距离如下图示:

需要注意的是,真实场景中的权值并不相等,若想得到关于距离升序的数组,不得不进行一次排序。

不管怎么说,现在我们得到 nn 个距离的数组。我们不妨设距离的限制为 KK

现在对于每个数组,进行一次双指针遍历操作:

  • 首先指针 ll 指向第一个元素,指针 rr 指向最后一个元素,累加器 sum = 0。
  • 接着下述两个步骤,直到 l=rl = r
  • 如果 l+r<=Kl+r <= K,那么 sum+=(rl)sum += (r-l),并且 l++l++
  • 如果 l+r>Kl+r > K,那么 rr--

l+r>Kl+r>K 时,说明第 ll 个节点和第 rr 个节点的距离超限了,又因为数组是递增的,所以 rr-- 尝试用一个距根节点更近的节点与 ll 匹配。

l+r<=Kl+r<=K 时,意味着第 ll 个节点,与第 l+1,l+2,...,rl+1,l+2,...,r 个节点的距离均不超过 KK,且此时 r+1r+1ll 的距离定然是超限的,所以 sum +=(rl)sum\ += (r-l)。然后 l++l++ 继续计算下一个节点的贡献。

但此时引入了一个新问题,得出的 sumsum 掺进了一些杂质,以节点 11 为例:

K=4K=4,那么:

sum= 71 +72 +73 +74 +75 +76\begin{align*} sum = &\ 7-1 \\ &\ + 7-2 \\ &\ + 7-3 \\ &\ + 7-4 \\ &\ + 7-5 \\ &\ + 7-6 \\ \end{align*}

其中类似于:

  • 4->2->1->2->5
  • 3->1->3->6

这种重复经过某些点的路径也加进去了。

这些路径可以通过下述方法剔除,以节点 11 为例:

  • 首先计算出 sum=21sum = 21
  • 然后分别计算出两棵子树中的节点,到 11 的距离。

  • 通过双指针的计算方法,分别计算出两棵子树贡献的错误路径数量均为:
(31)+(32)=3(3-1) + (3-2) = 3
  • 可得包含节点 11 的路径总数应为 2132=1521-3*2=15

这样按照如上流程,可以计算出所有 f(i)f(i)。那么该题就到此为止了?男人八题就这?

接下来讨论下时间复杂度。从上述流程中不难看出,一棵有 nn 个节点的树,要构造 2n2*n 个距离数组:

  • 一个是到子树根节点的。
  • 一个是到子树根节点的父节点的。

那么每棵子树包含的节点数量就很重要了。考虑树退化成链表的情况:

那么,每棵树的节点数量为:

  • tree1tree_1nn 个节点
  • treeitree_ini+1n-i+1 个节点
  • treentree_n11 个节点

那么时间复杂度就是 O(n2)O(n^2)。。。

接下来,引入一个新知识点树的重心。当把重心当做根节点时,可以保证最大子树的也不会超过n2\frac{n}{2}个节点。计算过程很简单,首先确认一个节点当做根节点,不妨先选最小的节点当做根。再设 cnticnt_i 是以 ii 为根的子树的节点数,othioth_i 为该子树之外的节点数量。

比如这样一棵树:

  • cnt1=4,oth1=0cnt_1 = 4, oth_1 = 0
  • cnt2=3,oth2=1cnt_2 = 3, oth_2 = 1
  • cnt3=2,oth3=2cnt_3 = 2, oth_3 = 2
  • cnt4=1,oth4=3cnt_4 = 1, oth_4 = 3

找出 max(othi,max(cnts))max(oth_i, max(cnt_s)) 最小的节点 i,该节点就是树的重心,记为 centercenter

max(cnts),si的子节点max(cnt_s), s ∈ {i的子节点},表示 ii 的最大的子树。

othioth_i 的意义其实是把 ii 的当做整棵树的根时,多出来的那颗子树,这个子树的根节点,就是现在 ii 的父节点~

如下图所示,红色部分就是节点 5 的 othoth

这样找出来根节点 centercenter 保证了最大的子树也不会超过n2\frac{n}{2}。让我们把链表变长,来看一看效果。

首先找到了 center=4center=4

对于两棵子树,分别找到重心为 2 和 6。

不难发现,通过把重心当做根的方式,可以保证子树的规模最起码会缩减一半。这样整体的时间复杂度就讲到了 O(nlgn)O(n*\lg n)

#include <stdio.h>
#include <string.h>
#include <algorithm>

using namespace std;

struct Node {
  int v, w, next;
}edge[20001];
int edge_cnt = 0;
int head[10001];

bool erase_flag[10001];
int node_count[10001];
int focus_mark[10001];
int queue[10001];

void CountNode(int root, int pre) {
  node_count[root] = 1;
  for (int i = head[root]; i != -1; i = edge[i].next) {
    int next = edge[i].v;
    if (next != pre && erase_flag[next] == false) {
      CountNode(next, root);
      node_count[root] += node_count[next];
    }
  }
}

int GetCenter(int root) {
  int cand_root = -1;
  int cand_root_threshold = 0x3f3f3f3f;

  int l = 0, r = 0;
  queue[r++] = root;
  while (l < r) {
    int f = queue[l++];
    int max_subtree_node_num = 0;
    int total_subtree_node_num = 0;

    for (int i = head[f]; i != -1; i = edge[i].next) {
      int next = edge[i].v;
      if (focus_mark[next] != root && erase_flag[next] == false) {
        queue[r++] = next;
        focus_mark[next] = root;
        total_subtree_node_num += node_count[next];
        max_subtree_node_num = max(max_subtree_node_num, node_count[next]);
      }
    }

    max_subtree_node_num = max(node_count[root] - total_subtree_node_num - 1, max_subtree_node_num);

    if (max_subtree_node_num < cand_root_threshold) {
      cand_root = f;
      cand_root_threshold = max_subtree_node_num;
    }
  }

  return cand_root;
}

void GetDist(int root, int pre, int pre_dist, int k, int *dist, int &cnt) {
  if (pre_dist > k) {
    return;
  }
  dist[cnt++] = pre_dist;
  for (int i = head[root]; i != -1; i = edge[i].next) {
    int next = edge[i].v;
    int d = edge[i].w;
    if (next != pre && erase_flag[next] == false) {
      GetDist(next, root, pre_dist + d, k, dist, cnt);
    }
  }
}

int total_dist[10001];
int GetTotalPair(int root, int pre, int pre_dist, int k) {
  int dist_cnt = 0;
  GetDist(root, pre, pre_dist, k, total_dist, dist_cnt);

  sort(total_dist, total_dist + dist_cnt);

  int total_pair = 0;
  for (int l = 0, r = dist_cnt-1; l < r;) {
    if (total_dist[l] + total_dist[r] <= k) {
      total_pair += r-l;
      l++;
    } else {
      r--;
    }
  }
  return total_pair;
}

int DivideAndConquer(int root, int k) {
  // 将 root 更新为重心
  CountNode(root, 0);
  root = GetCenter(root);
  int total_pair = GetTotalPair(root, 0, 0, k);
  for (int i = head[root]; i != -1; i = edge[i].next) {
    int next = edge[i].v;
    int dist = edge[i].w;
    if (erase_flag[next] == false) {
      total_pair -= GetTotalPair(next, root, dist, k);
    }
  }

  erase_flag[root] = true;
  for (int i = head[root]; i != -1; i = edge[i].next) {
    int next = edge[i].v;
    if (erase_flag[next] == false) {
      total_pair += DivideAndConquer(next, k);
    }
  }
  return total_pair;
}

int main() {
  int n, k;
  while(scanf("%d %d", &n, &k) && (n || k)) {
    edge_cnt = 0;
    memset(head, -1, sizeof(int)*(n+1));
    for (int i = 1, u, v, w; i < n; i++) {
      scanf("%d %d %d", &u, &v, &w);
      edge[edge_cnt].v = v;
      edge[edge_cnt].w = w;
      edge[edge_cnt].next = head[u];
      head[u] = edge_cnt++;

      edge[edge_cnt].v = u;
      edge[edge_cnt].w = w;
      edge[edge_cnt].next = head[v];
      head[v] = edge_cnt++;
    }
    memset(erase_flag, 0, sizeof(bool)*(n+1));
    memset(focus_mark, 0, sizeof(int)*(n+1));
    printf("%d\n", DivideAndConquer(1, k));
  }
  return 0;
}