#补充算法知识:树形DP(3道例题,逐行代码详解)

157 阅读3分钟

开启掘金成长之旅!这是我参与「掘金日新计划·2月更文挑战」的第20天,点击查看活动详情

树形DP

树形dp一般先算子树然后进行合并,在实现上与树的后序遍历有相似,如果不是二叉树,就根树的后序遍历不一样了。

原理:遍历子树(常用到DFS和BFS,建议能用BFS就不要用DFS,C++的栈空间较小,DFS容易爆栈),遍历完后把子树的值合并给父节点。

给定一棵有N个节点的树(通常是无根树,也就是有N-1条无向边),我们可以任选一个节点为根节点,从而定义出每个节点的深度和每颗子树的根。树形dp中,一般就以节点从深到浅(子树从小到大)的顺序作为dp的“节点”。在状态表示中,第一维通常是节点编号(代表以该节点为根的子树)

树形dp一般是由递归的方式实现。先算子树然后进行合并:对于每个节点x,先递归在它的每个子结点上进行dp,在回溯时,从子节点向节点x进行状态转移。

引入:给定一棵n个点的树(1号点位根节点),求以点 i 为根的子树的大小。

f[i] 以点 i 为根的子树的点的个数

f[i]=1+f[k]f[i] = 1 + \sum f[k](k是 i 的儿子)

遍历图的伪代码:

遍历的伪代码:
void dfs(i) {
    if (为叶节点) {
    	f[i] = 1;
    	return;
    }
    for (k是i的儿子) {
        dfs(k);
        f[i] += f[k];
    }
    f[i] += 1;
}

存图方式:链式前向星

int e[N], h[N], ne[N], idx; //链式前向星存图
void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx++; }

NC 51178. 没有上司的舞会

image.png

image.png

const int N = 6010;
bool hp[N];
int n, happy[N], f[N][N], e[N], h[N], ne[N], idx;;
void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx++; }  //链式前向星存图
void dfs(int u) {        //递归一下树
    f[u][1] = happy[u];
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        dfs(j);
        f[u][0] += max(f[j][0], f[j][1]);
        f[u][1] += f[j][0];
    }
}
int main() {
    cin >> n; 
    for (int i = 1; i <= n; i++) cin >> happy[i];
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i++) {    //将边加到图里,并且找出根节点
        int a, b;
        cin >> a >> b;
        hp[a] = true;
        add(b, a);
    }
    int root = 1;
    while (hp[root]) root++;
    dfs(root);//递归一下
    cout << max(f[root][0], f[root][1]) << endl;
    return 0;
}

AcWing 1072. 树的最长路径

AcWing 1072. 树的最长路径 题面

题解:

任取一个点为根节点,这里取1号点为根节点

因为是无向边,所以建图的时候,建的是双向的

所以在dfs的时候要判断一下,当前走的方向是哪个方向,只能向下搜索,不能向上搜索

因为根节点没有父节点,所以传入的时候只需要传入一个不存在的数字就可以了,这里用-1

d代表长度,dfs(当前节点,当前节点的父节点),当走到j之后,u就是j的父节点,w[i]为当前边的权值

只能是往下走,不能是往father节点的方向走

const int N = 10010, M = N * 2;
int n, h[N], e[M], w[M], ne[M], idx, ans;
void add(int a, int b, int c) {
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];
    h[a] = idx++;
}
int dfs (int u, int father) {    //u是当前节点,father是当前节点的父节点
    int dist = 0;                //表示从当前点往下走的最大长度
    int d1 = 0, d2 = 0;          //d1表示最大长度,d2表示次大长度
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;            //只能是往下走,不能是往father节点的方向走
        int d = dfs(j, u) + w[i];
        dist = max(dist, d);                  //更新一下从当前点往下走的最大长度
        if (d >= d1) {                        //更新一下最大值和次大值的 
            d2 = d1;
            d1 = d;
        } else if (d > d2) d2 = d;
    }
    ans = max(ans, d1 + d2);    //更新一下 路径两端的点的最远距离
    return dist;
}
int main() {
    cin >> n; 
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i++) {
        int a, b, c, cin >> a >> b >> c;
        add(a, b, c);     //建无向边
        add(b, a, c);     //建无向边
    }
    dfs(1, -1);
    cout << ans << endl;
    return 0;
}

AcWing 1073. 树的中心

在这里插入图片描述

在这里插入图片描述

求出每个点到当前树的最远点的距离

  1. 往子节点走的
  2. 往父节点走的

两遍树形dp

需要判断一下,向上走的方向中,父节点这个最长边是不是从当前这个子树上来的。

d1是往下走的最长路径,d2是往下走的次长路径,p1,p2记录上一个节点是哪个,up是向上走的最长路径

(链式前向星存图)

找出当前节点往下走的最长路径距离和次长路径距离

更新最大值和次大值,并记录最大值和次大值所在的路径,上一个节点。

如果是叶子节点,就证明这个点从来没有被更新过,就是-INF,叶子节点往下走就没有路了,就是0

往上走有两个方式:

  1. 如果最长路径是从子节点过来的,那么向上走的最长路径就只能用次大值来更新
  2. 如果最长路径不是从子节点过来的,那么向上走的最长路径就用最大值来更新
const int N = 10010, M = N + N, INF = 0x3f3f3f3f;
int n, h[N], e[M], w[M], ne[M], idx;
int d1[N], d2[N], p1[N], p2[N], up[N];  //d1是往下走的最长路径,d2是往下走的次长路径,p1,p2记录上一个节点是哪个,up是向上走的最长路径
void add(int a, int b, int c) { e[idx] = b; w[idx] = c; ne[idx] = h[a]; h[a] = idx++; }  //链式前向星
int dfs_d (int u, int father) {
    d1[u] = d2[u] = -INF;              //初始化距离为负无穷
    for (int i = h[u]; i != -1; i = ne[i]) {    //枚举当前节点的所有子节点
        int j = e[i];
        if (j == father) continue;
        int d = dfs_d(j, u) + w[i];
        if (d >= d1[u]) {点。
            d2[u] = d1[u];
            d1[u] = d;
            p2[u] = p1[u];
            p1[u] = j;
        } else if (d > d2[u]) {
            d2[u] = d;
            p2[u] = j;
        }
    }
    if (d1[u] == -INF) d1[u] = d2[u] = 0;
    return d1[u];
}

void dfs_u (int u, int father) {
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if (j == father) continue;
        //往上走有两个方式,
        if (p1[u] == j) up[j] = max(up[u], d2[u]) + w[i];    //如果最长路径是从子节点过来的,那么向上走的最长路径就只能用次大值来更新
        else up[j] = max(up[u], d1[u]) + w[i];     //如果最长路径不是从子节点过来的,那么向上走的最长路径就用最大值来更新
        dfs_u(j, u);
    }
}
int main() {    
    cin >> n;
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i++) {
        int a, b, c, cin >> a >> b >> c;
        add(a, b, c);
        add(b, a, c);
    }
    dfs_d(1, -1);
    dfs_u(1, -1);      //求一遍往上走的最大路径
    int res = INF;     //枚举每个点
    for (int i = 1; i <= n; i++) res = min(res, max(d1[i], up[i]));
    cout << res << endl;
    return 0;
}