ZJOI2008, 树的统计

66 阅读3分钟

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作:

  1. CHANGE u t,把结点u的权值改为t。
  2. QMAX u v,询问从点u到点v的路径上的节点的最大权值
  3. QSUM u v, 询问从点u到点v的路径上的节点的权值和。注意:从点u到点v的路径上的节点包括u和v本身。

输入格式

输入的第一行为一个整数n,表示节点的个数。接下来n–1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。

接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来一行,为一个整数q,表示操作的总数。

接下来q行,每行一个操作,以CHANGE u t或者QMAX u v或者QSUM u v的形式给出。

保证1≤n≤3×104,0≤q≤2×104,中途操作中保证每个节点的权值w在−3×104到3×104之间。

输出格式

对于每个QMAX或者QSUM的操作,每行输出一个整数表示要求输出的结果。

样例输入

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

样例输出

4
1
2
2
10
6
5
6
5
16

思路:我们把给定图放在数组中的规律为,找到该图的dfs序(即,dfs的顺序)我们先dfs一遍,找到重儿子,再次dfs对其编号,放在数组中,保证优先放入重儿子,如此,重链的顺序是连续的,按重链找出对应的dfs序,根据dfs序用线段树维护即可

代码

void AddEdge(int u,int v)
{
    to[++cnt]=v;
    nxt[cnt]=fst[u];
    fst[u]=cnt;
}
void dfs1(int u,int v)
{
    sz[u] = 1;
    hs[u] = -1;
    dep[u] = dep[v] + 1;
    fa[u] = v;
    for(int i=fst[u];i;i=nxt[i])
    {
        int c = to[i];
        if(c == v)
        continue;
        dfs1(c,u);
        sz[u] += sz[c];
        if(hs[u] == -1 || sz[c] > sz[hs[u]])
        hs[u] = c;
    }
}

void dfs2(int u,int v)
{
    top[u] = v;
    l1[u] = ++tot;
    id[tot] = u;
    if(hs[u] != -1)
    dfs2(hs[u],v);
    for(int i=fst[u];i;i=nxt[i])
    {
        int c = to[i];
        if(hs[u]!=c&&fa[u]!=c)
        dfs2(c,c);
    }
}

struct Node {
    ll mx;
    ll sum;
} seg[N * 4];

inline void update(int idx) {
    seg[idx].sum = (seg[idx * 2].sum + seg[idx * 2 + 1].sum );
    seg[idx].mx = max(seg[idx*2].mx,seg[idx*2+1].mx);
}

inline void build(int idx, int l, int r) {
    if (l == r)
    {
        seg[idx].mx = a[id[l]];
        seg[idx].sum = a[id[l]];
    }
    else {
        int mid = (r + l) / 2;
        build(idx * 2, l, mid);
        build(idx * 2 + 1, mid + 1, r);
        update(idx);
    }
}

inline void modify(int idx, int l, int r, int pos,int val) {
    if (l == r) {
        seg[idx].mx = val;
        seg[idx].sum = val;
        return;
    }
    int mid = (l + r) / 2;
    if (pos <= mid)
        modify(idx * 2, l, mid,pos,val);
    else
        modify(idx * 2 + 1, mid + 1, r,pos,val);
    update(idx);
}

ll query(int idx, int l, int r, int ql, int qr) {
    if (l == ql && r == qr)
        return seg[idx].sum;
    int mid = (r + l) / 2;
    if (qr <= mid)
        return query(idx * 2, l, mid, ql, qr);
    else if (ql > mid)
        return query(idx * 2 + 1, mid + 1, r, ql, qr);
    else
        return (query(idx * 2, l, mid, ql, mid) +
    query(idx * 2 + 1, mid + 1, r, mid + 1, qr));
}

ll querym(int idx, int l, int r, int ql, int qr) {
    if (l == ql && r == qr)
        return seg[idx].mx;
    int mid = (r + l) / 2;
    if (qr <= mid)
        return querym(idx * 2, l, mid, ql, qr);
    else if (ql > mid)
        return querym(idx * 2 + 1, mid + 1, r, ql, qr);
    else
        return max(querym(idx * 2, l, mid, ql, mid),
    querym(idx * 2 + 1, mid + 1, r, mid + 1, qr));
}

ll querysum(int u,int v)
{
    ll ans = 0;
    while(top[u]!=top[v])
    {
        if(dep[top[u]] < dep[top[v]])
        swap(u,v);
        ans += query(1,1,n,l1[top[u]],l1[u]);
        u = fa[top[u]];
    }
    if(dep[u] < dep[v])
        swap(u,v);
    ans += query(1,1,n,l1[v],l1[u]);
    return ans;
}

ll querymx(int u,int v)
{
    ll ans = -1e18;
    while(top[u]!=top[v])
    {
        if(dep[top[u]] < dep[top[v]])
        swap(u,v);
        ans = max(ans,querym(1,1,n,l1[top[u]],l1[u]));
        u = fa[top[u]];
    }
    if(dep[u] < dep[v])
        swap(u,v);
    ans = max(ans,querym(1,1,n,l1[v],l1[u]));
    return ans;
}
void solve()
{
    cin >> n;
    for(int i = 1;i < n;i ++)
    {
        int x,y;
        cin >> x >> y;
        AddEdge(x,y);
        AddEdge(y,x);
    }
    for(int i = 1;i <= n;i ++)
        cin >> a[i];
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    cin >> m;
    for(int i = 1;i <= m;i ++)
    {
        string s;
        int x,y;
        cin >> s;
        cin >> x >> y;
        if(s == "CHANGE")
        modify(1,1,n,l1[x],y);
        else if(s == "QSUM"){
        cout << querysum(x,y) << '\n';
    }
    else
        cout << querymx(x,y) << '\n';
    }
}
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    int tt = 1;
    while(tt--)
    {
        solve();
    }
}