前端算法第一七九期-树中距离之和

100 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第12天,点击查看活动详情

给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。

给定整数 n 和数组 edgesedges[i] = [ai, bi] 表示树中的节点 aibi 之间有一条边。

返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和。

示例 1:

输入: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 树如图所示。
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) 
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

示例 2:

输入: n = 1, edges = []
输出: [0]

示例 3:

输入: n = 2, edges = [[1,0]]
输出: [1,1]

树形动态规划

首先我们来考虑一个节点的情况,即每次题目指定一棵树,以 root\textit{root} 为根,询问节点 root\textit{root} 与其他所有节点的距离之和。

很容易想到一个树形动态规划:定义 dp[u]\textit{dp}[u] 表示以 uu 为根的子树,它的所有子节点到它的距离之和,同时定义 sz[u]\textit{sz}[u] 表示以 uu 为根的子树的节点数量。

其中 son[u]\textit{son}[u] 表示 uu 的所有后代节点集合。转移方程表示的含义就是考虑每个后代节点 vv,已知 vv 的所有子节点到它的距离之和为 dp[v]\textit{dp}[v]dp[v],那么这些节点到 uu 的距离之和还要考虑uv u\rightarrow v 这条边的贡献。考虑这条边长度为 1,一共有 sz[v]sz[v] 个节点到节点 u 的距离会包含这条边,因此贡献即为 1×sz[v]=sz[v]1\times \textit{sz}[v]=\textit{sz}[v]。我们遍历整棵树,从叶子节点开始自底向上递推到根节点 root\textit{root} 即能得出最后的答案为 dp[root]\textit{dp}[\textit{root}]

假设 uu 的某个后代节点为 vv,如果要算 vv 的答案,本来我们要以 vv 为根来进行一次树形动态规划。但是利用已有的信息,我们可以考虑树的形态做一次改变,让 vv 换到根的位置,uu 变为其孩子节点,同时维护原有的 dp\textit{dp}dp 信息。在这一次的转变中,我们观察到除了 uuvvdp\textit{dp} 值,其他节点的 dp\textit{dp} 值都不会改变,因此只要更新 dp[u]\textit{dp}[u]dp[v]\textit{dp}[v] 的值即可。

let ans, sz, dp, graph;
const dfs = (u, f) => {
    sz[u] = 1;
    dp[u] = 0;
    for (const v of graph[u]) {
        if (v === f) {
            continue;
        }
        dfs(v, u);
        dp[u] += dp[v] + sz[v];
        sz[u] += sz[v];
    }
}
const dfs2 = (u, f) => {
    ans[u] = dp[u];
    for (const v of graph[u]) {
        if (v === f) {
            continue;
        }
        const pu = dp[u], pv = dp[v];
        const su = sz[u], sv = sz[v];

        dp[u] -= dp[v] + sz[v];
        sz[u] -= sz[v];
        dp[v] += dp[u] + sz[u];
        sz[v] += sz[u];

        dfs2(v, u);

        dp[u] = pu, dp[v] = pv;
        sz[u] = su, sz[v] = sv;
    }
}
var sumOfDistancesInTree = function(n, edges) {
    ans = new Array(n).fill(0);
    sz = new Array(n).fill(0);
    dp = new Array(n).fill(0);
    graph = new Array(n).fill(0).map(v => []);
    for (const [u, v] of edges) {
        graph[u].push(v);
        graph[v].push(u);
    }
    dfs(0, -1);
    dfs2(0, -1);
    return ans;
};