LCA 求法

130 阅读2分钟

倍增法

#include<bits/stdc++.h>
using namespace std;

// solution of lca:
// 1. up-mark
//
// 2. binary lifting
// f(i, j) stores the node reached by jumping 2^j steps up from node i,
// so f(i, j) = f(f(i, j - 1), j - 1)
//
// depth(i) stores the depth of node i (1-indexed: the root has depth 1)
//
// for nodes i and j:
// first move the deeper one up until both are at the same depth,
// then jump both up together until they are
// directly under their lca
//
// 3. tarjan (offline method)
//

// Capacity of every per-node array; node ids index these arrays directly.
const int MAX_N = 40010;

// N: number of (a, b) input lines describing the tree.
// M: number of lca queries.
int N;
int M;

// Id of the tree root (the `a` of the input line whose `b` is -1).
int root;

// Adjacency lists stored as linked edge chains:
//   h[u]  - index of the first edge leaving u, or -1 when u has none
//   to[e] - target node of edge e
//   nx[e] - next edge leaving the same source, or -1 at the end of the chain
//   id    - next free edge slot
int h[MAX_N];
int to[MAX_N * 2];
int nx[MAX_N * 2];
int id;

// f[v][k] - node reached by jumping 2^k steps up from v (0 when the jump
//           leaves the tree); filled by bfs()
// depth[v] - depth of v, the root having depth 1; depth[0] stays 0
// vst[v]   - 1 once bfs() has reached v
int f[MAX_N][16];
int depth[MAX_N];
int vst[MAX_N];

// Prepares the graph storage before any edge is added.
void init () {
    // Node 0 doubles as the "no ancestor" marker; its depth of 0 is lower
    // than the depth of every real node (the root has depth 1).
    depth[0] = 0;

    // -1 in h[u] means "u has no outgoing edge yet".
    std::fill(std::begin(h), std::end(h), -1);
}

// Inserts the directed edge u -> v at the front of u's edge chain.
void add (int u, int v) {
    const int edge = id++;  // claim the next free edge slot
    to[edge] = v;
    nx[edge] = h[u];        // the previous first edge becomes the successor
    h[u] = edge;
}

// Breadth-first traversal starting at `root`. For every node it reaches it
// records the depth (root = 1) and fills the binary-lifting table f[][].
void bfs () {
    std::queue<int> pending;

    depth[root] = 1;
    vst[root] = 1;
    pending.push(root);

    while (!pending.empty()) {
        const int parent = pending.front();
        pending.pop();

        for (int e = h[parent]; e != -1; e = nx[e]) {
            const int child = to[e];
            if (vst[child]) continue;  // already reached (e.g. the edge back to the parent)

            vst[child] = 1;
            depth[child] = depth[parent] + 1;

            // The parent was dequeued before the child is discovered, so the
            // table rows of all ancestors of `child` are already complete and
            // each level can be derived from the previous one.
            int* up = f[child];
            up[0] = parent;
            for (int k = 1; k <= 15; k++) {
                up[k] = f[up[k - 1]][k - 1];
            }

            pending.push(child);
        }
    }
}

// Returns the lowest common ancestor of nodes a and b using the
// binary-lifting table f[][] and the depth[] array filled by bfs().
int lca(int a, int b) {
    // Keep `a` as the deeper (or equally deep) endpoint.
    if (depth[b] > depth[a]) std::swap(a, b);

    // Phase 1: lift `a` to the depth of `b`, trying the largest jumps first.
    // A jump is rejected when it would end above b's level; jumps that leave
    // the tree land on node 0, whose depth 0 is always rejected.
    for (int k = 15; k >= 0; --k) {
        const int lifted = f[a][k];
        if (depth[lifted] < depth[b]) continue;
        a = lifted;
    }

    // b itself was an ancestor of (or equal to) the original a.
    if (a == b) return a;

    // Phase 2: a != b and both are at the same depth. Jump both together
    // whenever the landing nodes still differ, so they stop directly below
    // their lowest common ancestor.
    for (int k = 15; k >= 0; --k) {
        const int lifted_a = f[a][k];
        const int lifted_b = f[b][k];
        if (lifted_a == lifted_b) continue;
        a = lifted_a;
        b = lifted_b;
    }

    // One more step up reaches the common ancestor.
    return f[a][0];
}

// Reads the tree and the queries from stdin and prints one answer per query.
// Input: N, then N lines "a b" (b == -1 marks a as the root, otherwise a-b is
// a tree edge), then M, then M lines "a b" asking for lca(a, b).
int main () {
    init();

    std::cin >> N;
    for (int line = 0; line < N; ++line) {
        int a, b;
        std::cin >> a >> b;

        if (b == -1) {
            root = a;
            continue;
        }

        // The tree is undirected: store the edge in both directions.
        add(a, b);
        add(b, a);
    }

    // Precompute depths and the jump table before answering anything.
    bfs();

    std::cin >> M;
    while (M--) {
        int a, b;
        std::cin >> a >> b;
        printf("the lca of %d and %d is %d\n", a, b, lca(a, b));
    }

    return 0;
}

Tarjan 离线

#include<bits/stdc++.h>
using namespace std;

// Traversal state of a vertex during the offline Tarjan DFS.
// Every vertex is in exactly one of three categories. NOT_VISIT is the zero
// value, so a zero-initialized global array of Status starts out as
// "not visited".
enum class Status {
    NOT_VISIT = 0,  // the DFS has not entered this vertex yet
    VISITING  = 1,  // the DFS has entered the vertex and not yet left it
    VISITED   = 2,  // the DFS has finished the vertex
};

// Array capacities: MAX_N bounds the vertex ids, MAX_M the number of queries.
constexpr int MAX_N = 10010;
constexpr int MAX_M = 20010;

// N: number of vertices (N - 1 edges are read); M: number of queries.
int N;
int M;

// Graph as linked edge chains:
//   h[u]  - index of the first edge leaving u, or -1 when u has none
//   to[e] - target vertex of edge e
//   wt[e] - weight of edge e (stored by add(); not read by the code shown here)
//   nx[e] - next edge leaving the same source, or -1 at the end of the chain
//   id    - next free edge slot
int h[MAX_N];
int to[MAX_N * 2];
int wt[MAX_N * 2];
int nx[MAX_N * 2];
int id;

// Union-find parent pointers; init() makes every vertex its own set.
int fa[MAX_N];

// DFS state per vertex; zero-initialized, i.e. Status::NOT_VISIT.
Status st[MAX_N];

// query[u] lists (other endpoint, query index) for every query touching u.
std::vector<std::pair<int, int>> query[MAX_N];

// src_query[i] keeps the endpoints of query i exactly as they were read.
std::pair<int, int> src_query[MAX_M];

// lca_list[i] is the answer to query i.
int lca_list[MAX_M];

// Inserts the directed edge u -> v with weight w at the front of u's edge chain.
void add (int u, int v, int w) {
    const int edge = id++;  // claim the next free edge slot
    to[edge] = v;
    wt[edge] = w;
    nx[edge] = h[u];        // the previous first edge becomes the successor
    h[u] = edge;
}

// Union-find lookup: returns the representative of the set containing x and
// re-points every vertex on the walked path directly at that representative.
int fnd (int x) {
    // Pass 1: walk up to the representative (the vertex that is its own parent).
    int rep = x;
    while (fa[rep] != rep) rep = fa[rep];

    // Pass 2: path compression. Stops once x is the representative itself
    // or already hangs directly below it.
    while (fa[x] != rep) {
        const int above = fa[x];
        fa[x] = rep;
        x = above;
    }

    return rep;
}

// Offline Tarjan LCA, run as a DFS from u.
//
// Key idea: a vertex that is still being visited, together with the subtrees
// of its children that are already finished, forms one union-find set whose
// representative is that vertex. When u inspects a query (u, v) and v is
// already finished, the representative of v's set is the lca of u and v.
void tarjan (int u) {
    st[u] = Status::VISITING;

    for (int e = h[u]; e != -1; e = nx[e]) {
        const int child = to[e];
        if (st[child] != Status::NOT_VISIT) continue;  // parent edge or already handled

        tarjan(child);

        // Key step: merge the child into u only AFTER its subtree is finished.
        // If the link were made before the recursive call, then on a chain
        // a -> b -> c -> d all links up to `a` would already exist when `c`
        // answers a query about (c, d), and fnd(d) would wrongly return `a`.
        fa[child] = u;
    }

    for (const auto& entry : query[u]) {
        const int other = entry.first;

        // Only finished vertices can be answered from here; the remaining
        // queries are answered later, when the DFS finishes their other
        // endpoint. A query with both endpoints equal to u never matches,
        // because st[u] is still VISITING at this point, so such queries
        // have to be answered outside this function.
        if (st[other] != Status::VISITED) continue;

        lca_list[entry.second] = fnd(other);
    }

    st[u] = Status::VISITED;
}

// Prepares the graph storage and the union-find structure.
void init () {
    // Every vertex starts as the representative of its own singleton set.
    std::iota(fa, fa + MAX_N, 0);

    // -1 in h[u] means "u has no outgoing edge yet".
    std::fill(std::begin(h), std::end(h), -1);
}

// Reads a weighted tree and M queries from stdin, answers all queries in one
// DFS from vertex 1, and prints the answers in input order.
// Input: "N M", then N - 1 lines "u v w" (edges), then M lines "x y" (queries).
int main () {
    init();

    std::cin >> N >> M;

    // N - 1 undirected edges, each stored in both directions.
    for (int edge = 1; edge < N; ++edge) {
        int u, v, w;
        std::cin >> u >> v >> w;
        add(u, v, w);
        add(v, u, w);
    }

    for (int qi = 0; qi < M; ++qi) {
        int x, y;
        std::cin >> x >> y;

        src_query[qi] = std::make_pair(x, y);

        // Register the query at both endpoints: whichever endpoint the DFS
        // finishes last is the one that answers it.
        query[x].emplace_back(y, qi);
        query[y].emplace_back(x, qi);

        // tarjan() never answers a query whose endpoints coincide, so it is
        // settled here.
        if (x == y) lca_list[qi] = x;
    }

    tarjan(1);

    for (int qi = 0; qi < M; ++qi) {
        const std::pair<int, int>& asked = src_query[qi];
        printf("the lca of %d and %d is %d\n", asked.first, asked.second, lca_list[qi]);
    }

    return 0;
}