【力扣roadmap】834. 树中距离之和

13 阅读2分钟

题目描述

image.png

思路

这题算术评级很高,但其实没那么难。 题目要求ans[i]是第i个节点与其他所有节点的距离和。 你先别想太多,你就先想这样的场景,一个根下面挂着三个子树,然后考虑ans[0]答案是多少。假设你知道0下面的三个子节点的答案ans[1],ans[2],ans[3]. 那么怎么转移到ans[0]呢?比如ans[1],就代表其下所有节点到1号位置的距离和。那么1号子树所有节点0号位置的距离和,是不是等于,ans[1] + size[1],其中size[1]是1号子树的节点数量(你想象ans[1]就是1下面所有节点走路走到节点1的代价总和;那么这群节点已经走到1了,再全部走到0的代价不就是基于ans[1]再加上size[1]么?) 这样一来,ans[0]就呼之欲出了

ans[0]=ans[child]+size[child]ans[0] = \sum ans[child] + size[child]

其中child0号节点的儿子

可见这是父亲的答案基于儿子算出,所以自底向上跑一遍dfs。这时候,你可以得到ans[0]和size数组。 size数组是以0为全局根得到的,size[i]表示以i为根节点的子树的大小。

你可以通过ans[0]和size数组推得全部ans数组。

我们先思考一下ans[1]怎么得到(这里假设1是0的直系儿子,ans[1]表示全部节点到1的距离)?

首先,1子树的全部节点不需要全部走1节点后,再走一步到0了,因为求ans[1],终点就是1节点. 所以ans[1]的大小,相比ans[0]来说,少了size[1].

其次,除了1子树的其余所有(n-size[1])个节点,走到0节点之后都得再多走一步走到1,所以ans[1]的大小,相比ans[0]来说,多了(n-size[1])。 因此,ans[1] = ans[0] - size[1] + n - size[1]. 推广得到

ans[child]=ans[dad]size[child]+nsize[child]ans[child] = ans[dad] - size[child] + n - size[child]

其中child是dad的子节点。可见这是自上而下的递推。再跑一遍dfs2就好了。 总时间复杂度O(n).

image.png

代码

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        
        if n == 1 : return [0]

        g = [[] for _ in range(n)]

        for a , b in edges :
            g[a].append(b)
            g[b].append(a) 

        
        dp = [0] * n 
        size = [0] * n # 以0为根  i为根的子树的大小
        
        def dfs(rt , dad) -> int : # 返回以rt为根节点的子树的size

            if len(g[rt]) == 1 and g[rt][0] == dad :
                dp[rt] = 0
                size[rt] = 1 
                return size[rt]
            
            subtree = 0
            size[rt] = 1 
            for c in g[rt]:

                if c == dad : continue 
                
                c_size = dfs(c,rt)
                print(c,rt,c_size)
                size[rt] += c_size
                dp[rt] += c_size + dp[c] 
            
            return size[rt]
        
        dfs(0,-1) 
        # print(dp)

        ans = [0] * n 
        ans[0] = dp[0]
        
        def dfs2(rt,dad) :
            nonlocal n
            for c in g[rt] :
                if c == dad :
                    continue 
                ans[c] = ans[rt] - size[c] + n - size[c]
                dfs2(c,rt) 
            
        dfs2(0,-1)
        return ans