【每日一LeetCode】公共祖先求解 倍增 Tarjan

75 阅读12分钟

2846. 边权重均等查询

B 站讲解链接

参考链接:https://oi-wiki.org//graph/lca/#tarjan-%E7%AE%97%E6%B3%95

image.png

image.png

方法: 朴素 【Python 超时】

class Solution {
public:
    vector<int> minOperationsQueries(int n, vector<vector<int>>& edges, vector<vector<int>>& queries) 
    {
        vector<vector<pair<int, int>>> g(n);
        for (auto& edge: edges){
            int a = edge[0], b = edge[1], w = edge[2];
            g[a].emplace_back(b, w);
            g[b].emplace_back(a, w);
        }
        
//  DFS 处理模块, 求取 根结点 到 结点node 的 count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
        
        // 一些 辅助
        vector<int> temp(27); // 中间辅助列表 记录  权重 为 i 的个数 为 temp[i]
        vector<int> parent(n, -1); // 结点 编号 为 i 的 父母节点为 parent[i]  求LCA 要用
        vector<int> level(n);  // 记录 结点 i 所在 的层编号   求LCA 要用

        vector<vector<int>> count(n, vector<int>(27, 0));  // count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数
        function<void(int)> dfs = [&](int cur) {//  count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
            for (auto [next_node, w]: g[cur]){// {next_node, weight}
                if (next_node != parent[cur]){
                    temp[w] += 1;
                    for (int i = 1; i <= 26; i++)
                        count[next_node][i] = temp[i];
                    
                    parent[next_node] = cur;
                    level[next_node] = level[cur] + 1;
                    
                    dfs(next_node);
                    temp[w] -= 1; // 恢复  因为 每个 结点 到 根结点的情况 不一样
                }
            }
        };

        dfs(0);  // 根据 结点编号 递归 依次处理 
        
// 获取 LCA 模块  
        function<int(int, int)> getLCA = [&](int p, int q){
            while (1){
                if (level[p] > level[q]){
                    p = parent[p]; // 更深的上移
                }
                else if (level[p] < level[q]){
                    q = parent[q];
                }
                else if (p == q){ // 找到了
                    return p;
                }
                else{
                    p = parent[p];
                    q = parent[q];
                }            
            }    
            return 0;
        };
        
        
//   处理 咨询  并 返回 结果
        vector<int> res;        
        for (auto query: queries)
        {
            int a = query[0], b = query[1];
            int lca = getLCA(a, b);
            
            // 计算 改成 一样需要的操作数   把 剩下的都改成 数量最多的那个数  边个数 - 阈值一样的最多边数
            int total = 0; //  总的边数   因为 可以修改成任意值, 所以 每条边 最多需要修改 1 次
            int mx = 0;  //  阈值一样 的边数 最大
            for (int i = 1; i <= 26; i++){
                int temp = count[a][i] + count[b][i] - 2 * count[lca][i];;
                total += temp;
                mx = max(mx, temp);
            }
            
            res.push_back(total - mx);
        }
        
        return res;
    }
      
};
class Solution:
    def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        g = [[] for _ in range(n)]
        for u, v, w in edges:
            g[u].append((v, w))
            g[v].append((u, w))

## DFS 处理 部分

        temp = [0] * 27  # 权重 为 i 的边数为 temp[i]
        parent = [-1] * n  #结点 i 的 父母结点 为 parent[i]
        level = [0] * n   # 记录 结点 i 所在 的层编号

        
        cnt = [[0 for j in range(27)] for _ in range(n)]  # count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
        def dfs(cur):
            for next_node, w in g[cur]:
                if next_node != parent[cur]:
                    temp[w] += 1
                    for i in range(1, 27):
                        cnt[next_node][i] = temp[i]
                
                    parent[next_node] = cur
                    level[next_node] = level[cur] + 1

                    dfs(next_node)
                    temp[w] -= 1

        dfs(0)

##      
        def getLCA(a, b):
            while True: # 必须 这样 
                if level[a] > level[b]:
                    a = parent[a]
                elif level[a] < level[b]:
                    b = parent[b]
                elif a == b:
                    return a
                else:
                    a = parent[a]
                    b = parent[b]

        # 处理 咨询 并返回 结果

        res = []
        for a, b in queries:
            lca_node = getLCA(a, b)
            total = 0
            mx = 0
            for i in range(1, 27):
                temp = cnt[a][i] + cnt[b][i] - cnt[lca_node][i] * 2
                total += temp 
                mx = max(mx, temp)

            res.append(total - mx)

        return res 

方法二: 倍增思路

class Solution {
public:
    vector<int> minOperationsQueries(int n, vector<vector<int>>& edges, vector<vector<int>>& queries) 
    {
        vector<vector<pair<int, int>>> g(n);
        for (auto& edge: edges){
            int a = edge[0], b = edge[1], w = edge[2];
            g[a].emplace_back(b, w);
            g[b].emplace_back(a, w);
        }
        
//  BFS 处理, 求取 根结点 到 结点node 的 count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
        
        // 一些 辅助
        vector<int> temp(27); // 中间辅助列表 记录  权重为 i 的边数 为 temp[i]
        vector<int> parent(n, -1); // 结点 编号 为 i 的 父母节点为 parent[i]
        vector<int> level(n);  // 记录 结点 i 所在 的层编号

        vector<vector<int>> count(n, vector<int>(27, 0));  // count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数
        
        function<void(int)> dfs = [&](int cur) {//  count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
            for (auto [next_node, w]: g[cur]){// {next_node, weight}
                if (next_node != parent[cur]){
                    temp[w] += 1;
                    for (int i = 1; i <= 26; i++)
                        count[next_node][i] = temp[i];
                    
                    parent[next_node] = cur;
                    level[next_node] = level[cur] + 1;
                    
                    dfs(next_node);
                    temp[w] -= 1; // 恢复  因为 每个 结点 到 根结点的情况 不一样
                }
            }
        };

        dfs(0);  // 根据 结点编号 递归 依次处理 
        
// 获取 LCA 模块   倍增思路
        // Step1: 先获取 ancestors[i][j]  结点 i 的 第 2^j 个 祖先  比如 找 第 k = 9 的祖先时, k = 2^3 + 2^1  这样
        int m = 32 - __builtin_clz(n); // n 的二进制 长度
        vector<vector<int>> ancestors(n, vector<int>(m, -1)); // ancestors[i][j]:  结点 i 的 第 2^j 个 祖先  父结点 为 第 1 个, 爷爷为 第 2 个
        for (int i = 0; i < n; ++i){
            ancestors[i][0] = parent[i]; // 2^0 = 1
        }
        for (int j = 1; j < m; ++j){
            for (int i = 1; i < n; ++i){
                if (ancestors[i][j-1] != -1){
                    ancestors[i][j] = ancestors[ancestors[i][j-1]][j-1];
                }
            }
        }
        //  倍增 跳转
        function<int(int, int)> getKthAncestor = [&](int node, int k){
            for (int j = 0; j < m; ++j){
                if ((k >> j) & 1){ // 第 j 位 非0, 直接 跳转
                    node = ancestors[node][j];
                    // if (node == -1){  // 在 调用的时候,限制了 j 的距离,必定存在,所以这里可去掉
                    //     return -1;
                    // }
                }
            }
            return node;
        };
        function<int(int, int)> getLCA = [&](int p, int q){            
            if (level[p] < level[q]){
                q = getKthAncestor(q, level[q] - level[p]);  // 更深的上移
            }else if (level[p] > level[q]){
                p = getKthAncestor(p, level[p] - level[q]);
            }
            
            int left = 0, right = level[p];
            while (left <= right){
                int mid = left + (right - left) / 2;
                if (getKthAncestor(p, mid) == getKthAncestor(q, mid))
                    right = mid - 1;
                else 
                    left = mid + 1;
            }
            return getKthAncestor(p, left);           
        };
        
        
//   处理 咨询  并 返回 结果
        vector<int> res;        
        for (auto query: queries)
        {
            int a = query[0], b = query[1];
            int lca = getLCA(a, b);

            // 计算 改成 一样需要的操作数   把 剩下的都改成 数量最多的那个数  边个数 - 阈值一样的最多边数
            int total = 0; //  总的边数   因为 可以修改成任意值, 所以 每条边 最多需要修改 1 次
            int mx = 0;  //  阈值一样 的边数 最大
            for (int i = 1; i <= 26; i++){
                int temp = count[a][i] + count[b][i] - 2 * count[lca][i]; // 权重为 i 的边数
                total += temp;
                mx = max(mx, temp);
            }            
            res.push_back(total - mx);
        }        
        return res;
    }
      
};
class Solution:
    def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        g = [[] for _ in range(n)]
        for u, v, w in edges:
            g[u].append((v, w))
            g[v].append((u, w))

## DFS 处理 部分

        temp = [0] * 27  # 权重 为 i 的边数为 temp[i]
        parent = [-1] * n  #结点 i 的 父母结点 为 parent[i]
        level = [0] * n   # 记录 结点 i 所在 的层编号

        
        cnt = [[0] * 27 for _ in range(n)]  # count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
        def dfs(cur):
            for next_node, w in g[cur]:
                if next_node != parent[cur]:
                    temp[w] += 1
                    for i in range(1, 27):
                        cnt[next_node][i] = temp[i]
                
                    parent[next_node] = cur
                    level[next_node] = level[cur] + 1

                    dfs(next_node)
                    temp[w] -= 1

        dfs(0)

##  LCA  倍增 
        m = n.bit_length()
        ancestors = [[-1] * m for _ in range(n)] # ancestors[i][j]:  结点 i 的 第 2^j 个 祖先   每个 k 都可以 分解成 多个 2^i 的 和
        for i in range(n):
            ancestors[i][0] = parent[i]

        for j in range(1, m):
            for i in range(1, n): # 可以 跳过 根结点
                if ancestors[i][j - 1] != -1:
                    ancestors[i][j] = ancestors[ancestors[i][j-1]][j-1]
        
        def getKthAncestor(node, k):
            for j in range(m):
                if (k >> j) & 1:
                    node = ancestors[node][j]
            return node 

        def getLCA(a, b):
            if (level[a] < level[b]):
                b = getKthAncestor(b, level[b] - level[a])
            elif level[a] > level[b]:
                a = getKthAncestor(a, level[a] - level[b])
            
            left, right = 0, level[a]
            while left <= right:
                mid = left + (right - left) // 2
                if getKthAncestor(a, mid) == getKthAncestor(b, mid):
                    right = mid - 1
                else:
                    left = mid + 1
            return getKthAncestor(a, left)


        # 处理 咨询 并返回 结果

        res = []
        for a, b in queries:
            lca_node = getLCA(a, b)
            total = 0
            mx = 0
            for i in range(1, 27):
                temp = cnt[a][i] + cnt[b][i] - cnt[lca_node][i] * 2
                total += temp 
                mx = max(mx, temp)

            res.append(total - mx)

        return res 

方法三: Tarjan(塔杨)算法 【离线】

image.png

B 站链接

image.png

class Solution {
public:
    vector<int> minOperationsQueries(int n, vector<vector<int>>& edges, vector<vector<int>>& queries) 
    {
        vector<vector<pair<int, int>>> g(n);
        for (auto& edge: edges){
            int a = edge[0], b = edge[1], w = edge[2];
            g[a].emplace_back(b, w);
            g[b].emplace_back(a, w);
        }
        
//  DFS 处理, 求取 根结点 到 结点node 的 count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
        
        // 一些 辅助
        vector<int> temp(27); // 中间辅助列表 记录  权重为 i 的边数 为 temp[i]
        vector<int> parent(n, -1); // 结点 编号 为 i 的 父母节点为 parent[i]

        vector<vector<int>> count(n, vector<int>(27, 0));  // count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数
        
        function<void(int)> dfs = [&](int cur) {//  count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
            for (auto [next_node, w]: g[cur]){// {next_node, weight}
                if (next_node != parent[cur]){
                    temp[w] += 1;
                    for (int i = 1; i <= 26; i++)
                        count[next_node][i] = temp[i];
                    
                    parent[next_node] = cur;
                    
                    dfs(next_node);
                    temp[w] -= 1; // 恢复  因为 每个 结点 到 根结点的情况 不一样
                }
            }
        };

        dfs(0);  // 根据 结点编号 递归 依次处理 
        
// 获取 LCA 部分   Tarjan
        vector<vector<pair<int, int>>> query(n); // u [v, i]
        int m = queries.size();
        for (int i = 0; i < m; ++i){
            int u = queries[i][0], v = queries[i][1];
            query[u].emplace_back(v, i);
            query[v].emplace_back(u, i);
        }
        vector<int> pa(n);
        for (int i = 0; i < n; ++i){
            pa[i] = i; // 指向自己      结点 i 指向 结点 pa[i]
        }

        // 压缩 模块
        function<int(int)> find = [&](int u){
            if (pa[u] == u){
                return u;
            }
            return pa[u] = find(pa[u]);
        };

        // 

        vector<int> LCA(m); // 第 i 组 查询的 lca结点编号为 LCA[i]
        vector<int> visited(n);
        function<void(int)> tarjan = [&](int u){
            visited[u] = true;  // 标记
            for (auto [v, _] : g[u]){
                if (!visited[v]){
                    tarjan(v);
                    pa[v] = u; // 回 u 时, v 指向 u   指向 父母结点 
                }                
            }
            // 离开 u时,枚举 LCA 记录
            for (auto [v, i] : query[u]){
                if (visited[v]){
                    LCA[i] = find(v);
                }
            }            
        };

        tarjan(0);  // 记得 调用!!

        
//   处理 咨询  并 返回 结果
        vector<int> res;        
        for (int i = 0; i < m; ++i){
            int a = queries[i][0], b = queries[i][1];
            int lca = LCA[i]; // 

            // 计算 改成 一样需要的操作数   把 剩下的都改成 数量最多的那个数  边个数 - 阈值一样的最多边数
            int total = 0; //  总的边数   因为 可以修改成任意值, 所以 每条边 最多需要修改 1 次
            int mx = 0;  //  阈值一样 的边数 最大
            for (int i = 1; i <= 26; i++){
                int temp = count[a][i] + count[b][i] - 2 * count[lca][i]; // 权重为 i 的边数
                total += temp;
                mx = max(mx, temp);
            }            
            res.push_back(total - mx);
        }        
        return res;
    }
      
};
class Solution:
    def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        g = [[] for _ in range(n)]
        for u, v, w in edges:
            g[u].append((v, w))
            g[v].append((u, w))

## DFS 处理 部分

        temp = [0] * 27  # 权重 为 i 的边数为 temp[i]
        parent = [-1] * n  #结点 i 的 父母结点 为 parent[i]
        
        cnt = [[0] * 27 for _ in range(n)]  # count[node][w_i]: 从 根结点 到 结点 node 的 路径上 权重 为 w_i 的边数 
        def dfs(cur):
            for next_node, w in g[cur]:
                if next_node != parent[cur]:
                    temp[w] += 1
                    for i in range(1, 27):
                        cnt[next_node][i] = temp[i]
                
                    parent[next_node] = cur

                    dfs(next_node)
                    temp[w] -= 1

        dfs(0)

##  LCA  Tarjan
        query = [[] for _ in range(n)]
        m = len(queries)
        for i in range(m):
            u, v = queries[i][0], queries[i][1]
            query[u].append((v, i))
            query[v].append((u, i))

        pa = [i for i in range(n)]  # 结点 i 指向 结点pa[i] 。 初始为指向自身
        
        # 指向 压缩
        def find(x):
            if x != pa[x]:
                pa[x] = find(pa[x])
            return pa[x]

        LCA = [0] * m  # 记录 第 i 组 查询的 lca 结点 编号
        visited = [False] * n 
        def tarjan(u):
            visited[u] = True # 标记 
            for v, w in g[u]:
                if not visited[v]:
                    tarjan(v)
                    pa[v] = u  #  回 u 时, 让 v 指向 u 
                
            #   离开 u 了, 查看 之前 是否 查询过
            for v, i in query[u]:
                if visited[v]:
                    LCA[i] = find(v)

        tarjan(0)   #  记得 调用!!!


# 处理 咨询 并返回 结果
        res = []
        for i in range(m):
            a, b = queries[i][0], queries[i][1]
            lca_node = LCA[i]
            total = 0
            mx = 0
            for i in range(1, 27):
                temp = cnt[a][i] + cnt[b][i] - cnt[lca_node][i] * 2
                total += temp 
                mx = max(mx, temp)

            res.append(total - mx)

        return res 



1483. 树节点的第 K 个祖先

image.png

class TreeAncestor:

    def __init__(self, n: int, parent: List[int]):
        self.m = n.bit_length() # n 的二进制 长度
        self.ancestors = [[-1] * self.m for _ in range(n)] # 
        for i in range(n):
            self.ancestors[i][0] = parent[i]
        
        for j in range(1, self.m):
            for i in range(n):
                if self.ancestors[i][j - 1] != -1:
                    self.ancestors[i][j] = self.ancestors[self.ancestors[i][j - 1]][j - 1]

    def getKthAncestor(self, node: int, k: int) -> int:
        for j in range(self.m):
            if (k >> j) & 1: # 不断检查 爷爷 结点
                node = self.ancestors[node][j]
                if node == -1:
                    return -1
        return node



# Your TreeAncestor object will be instantiated and called as such:
# obj = TreeAncestor(n, parent)
# param_1 = obj.getKthAncestor(node,k)

# ancestors[i][j]: 节点 i 的 第 2^j 个 祖先  
# 将 k 用 二进制表示。 若  第 j 位 为 1, 结点 node 转移到  ancestors[node][j]
# 状态转移 方程: ancestors[i][j] = ancestors[ancestors[i][j -1]][j - 1]
# ancestors[i][0] = parent[i]


image.png

class TreeAncestor {
public:
    constexpr static int m = 16; // 本题 的 n 在 2**16 之内 
    vector<vector<int>> ancestors;

    TreeAncestor(int n, vector<int>& parent) {
        ancestors = vector<vector<int>> (n, vector<int>(m, -1));
        for (int i = 0; i < n; ++i){
            ancestors[i][0] = parent[i];
        }
        for (int j = 1; j < m; ++j){
            for (int i = 0; i < n; ++i){
                if (ancestors[i][j - 1] != -1){
                    ancestors[i][j] = ancestors[ancestors[i][j - 1]][j - 1];
                }
            }
        }
    }
    
    int getKthAncestor(int node, int k) {
        for (int j = 0; j < m; ++j){
            if ((k >> j) & 1){
                node = ancestors[node][j];
                if (node == -1){
                    return -1;
                }
            }
        }
        return node;
    }
};

/**
 * Your TreeAncestor object will be instantiated and called as such:
 * TreeAncestor* obj = new TreeAncestor(n, parent);
 * int param_1 = obj->getKthAncestor(node,k);
 */

1104. 二叉树寻路 O(loglabel)O(1)\lgroup O(\log label)、O(1) \rgroup

image.png

image.png

image.png

class Solution:
    def pathInZigZagTree(self, label: int) -> List[int]:
        #  根结点 到 label 经过的 结点
        # 之 体现在 结点序号。 
        res = []
        while label != 1:
            res.append(label)
            label >>= 1
        res.append(1)
        res.sort()
        # 修改
        for i in range(len(res)-2, -1, -2): # 最后 一个为 label 必定 不能改
            res[i] = 2**(i + 1) + 2**i - 1 - res[i]
        return res 

class Solution:
    def pathInZigZagTree(self, label: int) -> List[int]:
        #  根结点 到 label 经过的 结点
        # 之 体现在 结点序号。 
        res = []
        while label != 1:
            res.append(label)
            label >>= 1
            label = label ^ (1 << (label.bit_length() - 1)) - 1
        return [1] + res[::-1]

class Solution {
public:
    vector<int> pathInZigZagTree(int label) {
        vector<int> res;
        while (label != 1){
            res.push_back(label);
            label = label / 2; 
        }
        res.push_back(1);
        sort(res.begin(), res.end());

        for (int i = res.size() - 2; i >= 0; i = i - 2 ){
            res[i] = (1 << (i + 1)) + (1 << i) - 1 - res[i];
        }
        return res;
    }
};

LCP 08. 剧情触发时间

image.png

class Solution:
    def getTriggerTime(self, increase: List[List[int]], requirements: List[List[int]]) -> List[int]:
        C, R, H = [0], [0], [0]  # 结果 从 1 开始
        for c, r, h in increase:
            C.append(C[-1] + c)
            R.append(R[-1] + r)
            H.append(H[-1] + h)

        res = []
        for rc, rr, rh in requirements:
            resC = bisect.bisect_left(C, rc)  # bisect.bisect_left返回大于等于x的第一个下标
            resR = bisect.bisect_left(R, rr)
            resH = bisect.bisect_left(H, rh)
            r = max(resC, resH, resR)
            if r == len(increase) + 1: ###!!
                r = -1
            res.append(r)
        return res 
class Solution {
public:
    vector<int> getTriggerTime(vector<vector<int>>& increase, vector<vector<int>>& requirements) {
        // std::lower_bound算法用于在有序数组中寻找第一个大于或等于给定值的元素的位置。
        vector<int> C, R, H;
        C.push_back(0);
        R.push_back(0);
        H.push_back(0);
        for (auto i : increase){
            C.push_back(C.back() + i[0]);
            R.push_back(R.back() + i[1]);
            H.push_back(H.back() + i[2]);
        }
        vector<int> res;
        for (auto re : requirements){
            int idx_c = lower_bound(C.begin(), C.end(), re[0]) - C.begin();
            int idx_r = lower_bound(R.begin(), R.end(), re[1]) - R.begin();
            int idx_h = lower_bound(H.begin(), H.end(), re[2]) - H.begin();
            int r = max(max(idx_c, idx_h), idx_r);
            if (r == increase.size() + 1){
                r = -1;
            }
            res.push_back(r);

        }
        return res;
    }
};

image.png