2581. 统计可能的树根数目 【换根DP】

42 阅读1分钟

2581. 统计可能的树根数目

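题目：给定一棵 n 个节点（编号 0 到 n-1）的无向树，用长度为 n-1 的数组 edges 表示，edges[i] = [a_i, b_i] 表示 a_i 与 b_i 之间有一条边。Bob 做了若干次猜测，guesses[j] = [u_j, v_j] 表示他猜测 u_j 是 v_j 的父节点（[u_j, v_j] 保证是树中的一条边）。Alice 只告诉 Bob：这些猜测中至少有 k 个是正确的。返回可能成为树根的节点数目。

思路（换根 DP）：先以 0 为根遍历一次，统计此时正确的猜测数 cnt0。把根从 u 移到相邻节点 v 时，只有边 (u, v) 的父子方向发生反转，因此 cnt[v] = cnt[u] − [(u, v) ∈ guesses] + [(v, u) ∈ guesses]。最后统计 cnt ≥ k 的节点个数即可。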

class Solution:
    def rootCount(self, edges: List[List[int]], guesses: List[List[int]], k: int) -> int:
        """Count the nodes that, taken as the root, make at least ``k`` guesses correct.

        A guess ``[u, v]`` claims that ``u`` is the parent of ``v``. The tree is first
        rooted at node 0 to count the correct guesses there. The root is then moved
        across each edge (rerooting DP). Moving the root from ``u`` to its neighbour
        ``v`` only reverses the direction of edge ``(u, v)``, so
        ``cnt[v] = cnt[u] - [(u, v) guessed] + [(v, u) guessed]``.

        Both passes are iterative, so deep (e.g. path-shaped) trees do not hit
        Python's recursion limit.

        Args:
            edges: the ``n - 1`` undirected edges of a tree on nodes ``0..n-1``.
            guesses: ``[parent, child]`` guesses.
            k: minimum number of correct guesses required.

        Returns:
            The number of nodes that are valid roots.
        """
        n = len(edges) + 1
        g = [[] for _ in range(n)]
        for u, v in edges:
            g[u].append(v)
            g[v].append(u)

        guessed = {(u, v) for u, v in guesses}  # hash set for O(1) membership tests

        # Iterative DFS from root 0. Record each node's parent and a visit order
        # in which every parent appears before all of its children.
        parent = [-1] * n
        order = []
        stack = [0]
        while stack:
            u = stack.pop()
            order.append(u)
            for v in g[u]:
                if v != parent[u]:
                    parent[v] = u
                    stack.append(v)

        # Number of correct guesses when rooted at 0: one per (parent, child) edge.
        cnt = [0] * n
        cnt[0] = sum((parent[v], v) in guessed for v in order[1:])

        # Rerooting: parents are processed before children, so cnt[parent[v]] is ready.
        for v in order[1:]:
            u = parent[v]
            cnt[v] = cnt[u] - ((u, v) in guessed) + ((v, u) in guessed)

        return sum(cnt[v] >= k for v in order)

参考链接（原文链接已丢失）

C++ 版本：

class Solution {
public:
    // Counts the nodes that, taken as the root, make at least k guesses correct.
    //
    // A guess [u, v] claims that u is the parent of v. Root the tree at 0 and
    // count the correct guesses there, then move the root across each edge
    // (rerooting DP). Moving the root from u to its neighbour v only reverses
    // edge (u, v), so cnt[v] = cnt[u] - [(u, v) guessed] + [(v, u) guessed].
    //
    // Both passes are iterative (explicit stack plus a recorded visit order),
    // so a path-shaped tree does not need one nested call frame per node.
    //
    // edges:   n - 1 undirected edges of a tree on nodes 0..n-1.
    // guesses: [parent, child] guesses.
    // k:       minimum number of correct guesses required.
    // Returns the number of valid root nodes.
    int rootCount(vector<vector<int>>& edges, vector<vector<int>>& guesses, int k) {
        int n = edges.size() + 1;

        // Build the adjacency list; iterate by const reference to avoid copying each edge.
        vector<vector<int>> g(n);
        for (const auto& e : edges) {
            g[e[0]].emplace_back(e[1]);
            g[e[1]].emplace_back(e[0]);
        }

        set<pair<int, int>> guess;
        for (const auto& e : guesses) {
            guess.insert(make_pair(e[0], e[1]));
        }
        // 1 if the directed guess (u, v) was made, else 0.
        auto guessed = [&](int u, int v) { return guess.count(make_pair(u, v)) ? 1 : 0; };

        // Iterative DFS from root 0: record each node's parent and a visit
        // order in which every parent appears before all of its children.
        vector<int> parent(n, -1);
        vector<int> order;
        order.reserve(n);
        vector<int> stk{0};
        while (!stk.empty()) {
            int u = stk.back();
            stk.pop_back();
            order.push_back(u);
            for (int v : g[u]) {
                if (v != parent[u]) {
                    parent[v] = u;
                    stk.push_back(v);
                }
            }
        }

        // cnt[x] = number of correct guesses when x is the root.
        // order[0] is the root 0; every later entry has a valid parent.
        vector<int> cnt(n, 0);
        int m = order.size();
        for (int i = 1; i < m; ++i) {
            int v = order[i];
            cnt[0] += guessed(parent[v], v);
        }

        // Rerooting: parents are visited first, so cnt[parent[v]] is already final.
        for (int i = 1; i < m; ++i) {
            int v = order[i];
            int u = parent[v];
            cnt[v] = cnt[u] - guessed(u, v) + guessed(v, u);
        }

        int res = 0;
        for (int v : order) {
            res += cnt[v] >= k;
        }
        return res;
    }
};