树形DP之换根DP

296 阅读3分钟

换根DP主要解决这样一类问题,树的根节点不固定,随着根节点的改变,树的深度和、点权和等属性会产生变化,求相关的属性值。

以(834. 树中距离之和 - 力扣(Leetcode))为例,该题求各个节点到其他节点的距离之和。

834.png

基本的思路进行两次DFS,第一次DFS先以其中一个节点为根节点,计算这个节点到其他节点的距离和,并保存子节点数量信息,第二次DFS通过动态规划计算其相邻节点到其他节点的距离和,再递归地计算其他节点。

首先要进行预处理,将图存在邻接表中,可以用map实现邻接表,key是父节点,value是子节点组成的数组。因为本题的元素是从0...n-1,因此也可以用数组实现邻接表。

table := make(map[int][]int)
for _, edge := range edges {  
    if len(edge) == 0 {  
        return []int{0}  
    }  
    table[edge[0]] = append(table[edge[0]], edge[1])  
    table[edge[1]] = append(table[edge[1]], edge[0])  
}

接着就可以开始第一次DFS。如上图所示,在本题中,可以以0作为根节点,用一个距离d表示当前节点到根节点的距离,自顶向下递归,每次递归,距离d都要+1,并将距离d加到总距离dis中。除了计算总距离,还要自底向上计算每个节点的子节点的个数,最先确定的是子节点只有自身的节点,son=1,再逐层将子节点个数加到上一层。

var dfs func(father int, now int, d int)  
dis := 0  
sons := make([]int, n)  
dfs = func(father int, now int, d int) {  
    t := table[now]  
    d++  
    sons[now] = 1  
    for i := 0; i < len(t); i++ {  
        //排除父节点
        if t[i] != father {  
            dis += d  
            dfs(now, t[i], d)  
            //将下一层的子节点个数加到这一层
            sons[now] += sons[t[i]]  
        }  
    }  
}  
//父节点为-1,根节点为0,根节点距离为0  
dfs(-1, 0, 0)

第二次DFS,要先计算0的相邻节点1和2的距离和,再往下递归。以节点2为例,设根节点0的距离和为dis0,所有节点2的子节点,包括节点2本身的距离都要-1,在本例中,节点2,3,4,5的距离减1;所有非节点2子节点的节点,包括原根节点,距离都要+1,在本例中,0,1的距离+1。一般地,设原节点的距离和为disodis_o,当前节点的距离和为disndis_n,当前节点的子节点数量为son,节点总数为n,则可以得到状态转移方程:

disn=disoson+nsondis_n = dis_o - son + n - son

在代码实现中,设置一个结果数组res[n],每递归到一个节点,则将相应的距离和更新到数组对应位置中。

res := make([]int, n)  
res[0] = dis  
var reRoot func(fatherDis int, father int, now int)  
reRoot = func(fatherDis int, father int, now int) { 
    t := table[now]  
    if now != 0 {  
        res[now] = fatherDis - sons[now] + n - sons[now] 
    }  
    for i := 0; i < len(t); i++ {  
        if father != t[i] {  
            reRoot(res[now], now, t[i])  
        }  
    }  
}    
reRoot(dis, -1, 0)

完整代码如下:

func sumOfDistancesInTree(n int, edges [][]int) []int {
    table := make(map[int][]int)
    for _, edge := range edges {  
        if len(edge) == 0 {  
            return []int{0}  
        }  
        table[edge[0]] = append(table[edge[0]], edge[1]) 
        table[edge[1]] = append(table[edge[1]], edge[0]) 
    }
    
    var dfs func(father int, now int, d int)  
    dis := 0  
    sons := make([]int, n)  
    dfs = func(father int, now int, d int) {  
        t := table[now]  
        d++  
        sons[now] = 1  
        for i := 0; i < len(t); i++ {  
            //排除父节点
            if t[i] != father {  
                dis += d  
                dfs(now, t[i], d)  
                //将下一层的子节点个数加到这一层
                sons[now] += sons[t[i]]  
            }  
        }  
    }  
    //父节点为-1,根节点为0,根节点距离为0  
    dfs(-1, 0, 0)
    
    res := make([]int, n)  
    res[0] = dis  
    var reRoot func(fatherDis int, father int, now int) 
    reRoot = func(fatherDis int, father int, now int) { 
        t := table[now]  
        if now != 0 {  
            res[now] = fatherDis - sons[now] + n - sons[now] 
        }  
        for i := 0; i < len(t); i++ {  
            if father != t[i] {  
                reRoot(res[now], now, t[i])  
            }  
        }  
    }    
    reRoot(dis, -1, 0)
    
    return res
}

每次DFS,图中所有的节点都遍历一次,时间复杂度为O(2n)=O(n)。