2867. 统计树中的合法路径数目 【埃氏筛标记质数、符合条件的树路径数统计】

105 阅读2分钟

2867. 统计树中的合法路径数目

每日链接

the sieve of Eratosthenes - 埃拉托斯特尼筛法

class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:
        # 深度 优先搜索

        # 要点 1: 快速判断 一个数 是不是 质数  埃氏筛。
        # 要点 2: 以 质数结点 为根, 搜索所有 非质数的子树 大小

        # 预处理, 对结点编号 1 - n 质数判断
        is_prime = [True] * (n + 1) 
        is_prime[1] = False
        for i in range(2, isqrt(n) + 1):
            if is_prime[i]: # 发现质数 x, 其倍数是合数 2*x 3*x 4*x
                for j in range(i * i, n + 1, i) :
                    is_prime[j] = False 

        # 建图
        G = [[] for _ in range(n + 1)]
        for u, v in edges:
            G[u].append(v)
            G[v].append(u)

        # 深度 优先搜索 
        def dfs(u, pa): 
            nodes.append(u)
            for v in G[u]:  # 隐含的结束递归的条件
                if v != pa and not is_prime[v]: # 递归统计 非质数结点 个数
                    dfs(v, u) 

        res = 0
        cnt = [0] * (n + 1)  # 包含 结点 i 的非素数路径 上的节点数
        for i in range(1, n + 1): # 依次 以 i 为 根  且只讨论 i 为 质数的情况
            if not is_prime[i]:
                continue  # 跳过 质数
            cur = 0
            for j in G[i]:
                if is_prime[j]:
                    continue 
                if cnt[j] == 0: # j 为 非质数  且 未统计过
                    nodes = [] # 统计 和 j 相连的 非质数路径 结点个数
                    dfs(j, 0)
                    for node in nodes:
                        cnt[node] = len(nodes)

                res += cnt[j] * cur 
                cur += cnt[j]  # 
            res += cur #  
        return res 
        # dp[i][0]: 以 i 为根结点 不包含质数 的路径数。 dp[i][1] 包含一个质数
        # i 不为质数  dp[i][0] = sum(dp[child][0]) + 1           dp[i][1] = sum(dp[i][1])
        # i 为质数  dp[i][0] = 0       dp[i][1] = sum(dp[child][0]) + 1

C++

// // 根据数据量级  进行数据预处理  埃氏筛 标记素数
// const int N = 1e5 + 1;
// bool is_prime[N]; // 质数=true, 非质数 false
// int init = [](){
//     // 埃氏筛
//     fill(begin(is_prime), end(is_prime), true);
//     is_prime[1] = false;
//     for (int i = 2; i * i < N; ++i){
//         if (is_prime[i]){
//             for (int j = i * i; j < N; j += i){
//                 is_prime[j] = false;
//             }
//         }
//     }
//     return 0;
// }();



class Solution {
public:
    long long countPaths(int n, vector<vector<int>>& edges) {
        bool is_prime[n + 1];
        memset(is_prime, true, sizeof(is_prime)); // 初始化 为 true
        is_prime[1] = false; // 非质数
        for (int i = 2;  i * i <= n; ++i){
            if (is_prime[i]){
                for (int j = i * i; j <= n; j += i){
                    is_prime[j] = false;
                }
            }
        }

        // 建图
        vector<vector<int>> G(n + 1);
        for (auto edge : edges){
            int u = edge[0];
            int v = edge[1];
            G[u].emplace_back(v);
            G[v].emplace_back(u);
        }

        vector<int> cnt(n + 1, 0);  // 结点 i 附近 的 非质数结点 个数
        vector<int> nodes; // 临时存储 非质数 结点,便于 个数统计
        function<void(int, int)> dfs = [&](int u, int pre){
            nodes.emplace_back(u);
            for (int v : G[u]){
                if (v != pre && !is_prime[v]){
                    dfs(v, u);
                }
            }
        };

        long long res = 0;
        for (int i = 1; i <= n; ++i){// 遍历 质数
            if (!is_prime[i]) continue;
            long long cur = 0;
            for (int j : G[i]){// 计算 邻近的 非质数个数
                if (is_prime[j]) continue;
                if (cnt[j] == 0){// 未计算过
                    nodes.clear();
                    dfs(j, 0);
                    for (int node : nodes){
                        cnt[node] = nodes.size();
                    }
                }

                // 之前 计算过了
                res += (long long) cnt[j] * cur;
                cur += cnt[j];
            }
            res += cur;
        }
        return res;
    }
};