ZJOI2008, 树的统计 树链剖分模板

//题意:给定一棵树,现在我需要询问以下操作
//      1.q,u之间的最小值
//      2.q,u之间的简单路径的权值和
//      3.修改树上q点的权值
//思路:如果是在一段序列上的问题,我们可以直接线段树解决,但是这是一棵树,我们也无法将两点之间的简单路径转化为一段连续区间
//      所以我们使用树链剖分(重链剖分),如此一来我们可以用类似于倍增的方式维护出一条简单路径(简单路径上的信息可以用线段树维护)
//        因为我们重链剖分后,按照优先遍历重儿子的理念,一条重链一定在dfs序中是连续的,这也方便了我们用线段树查询
// 
//复杂度:每次维护区间信息 区间剖分logn*线段树维护*logn 所以是O(logn*logn)
//   
#include<bits/stdc++.h>
using namespace std;
#define int long long

const int N = 3 * 1e5;

int n, m, a[N];
vector<int> e[N];
int l[N], r[N], idx[N];
int sz[N], hs[N], tot, top[N], dep[N], fa[N];

struct info {
    int maxv, sum;//要记录的路径最大值,与路径之和
    info(int a=0, int b=0) {
        maxv = a, sum = b;
    }
};

info operator +(const info& l, const info& r) {
    return { max(l.maxv, r.maxv), l.sum + r.sum };//重载合并运算
}

struct node {
    info val;
}seg[N * 4];

void update(int id) {
    seg[id].val = seg[id * 2].val + seg[2 * id + 1].val;
}

void build(int id, int l, int r) {
    if (l == r) {
        //我们是对dfs序建树,所以这里要注意是dfs序中l号点
        seg[id].val = { a[idx[l]],a[idx[l]] };
    }
    else {
        int mid = (l + r) / 2;
        build(id * 2, l, mid);
        build(id * 2 + 1, mid + 1, r);
        update(id);
    }
}

void change(int id, int l, int r, int pos, int val) {//本题只有单点修改
    if (l == r) {
        seg[id].val = { val,val };
    }
    else {
        int mid = (l + r) / 2;
        if (pos <= mid) change(id * 2, l, mid, pos, val);
        else change(id * 2 + 1, mid + 1, r, pos, val);
        update(id);
    }
}

info query(int id, int l, int r, int ql, int qr) {
    if (l == ql && r == qr) return seg[id].val;
    int mid = (l + r) / 2;
    if (qr <= mid) return query(id * 2, l, mid, ql, qr);
    else if (ql > mid) return query(id * 2 + 1, mid + 1, r, ql, qr);
    else {
        return query(id * 2, l, mid, ql, mid) +
            query(id * 2 + 1, mid + 1, r, mid + 1, qr);
    }
}

void dfs1(int u, int f) {//第一遍dfs,求子树大小,重儿子,父亲,深度
    sz[u] = 1;
    hs[u] = -1;
    fa[u] = f;
    dep[u] = dep[f] + 1;
    for (auto v : e[u]) {
        if (v == f) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if (hs[u] == -1 || sz[v] > sz[hs[u]])
            hs[u] = v;
    }
}

void dfs2(int u, int t) {//第二次dfs,求每个点的dfs序,重链上的链头元素
    top[u] = t;
    l[u] = ++tot;
    idx[tot] = u;
    if (hs[u] != -1) {
        dfs2(hs[u], t);
    }//优先遍历重链,如此一来将重链的dfs序搞成连续的
    for (auto v : e[u]) {
        if (v != fa[u] && v != hs[u])
            dfs2(v, v);//这里很重要,最开始被dls坑了
    }
    r[u] = tot;
}

info query1(int u, int v) {
    info ans{ (int)-1e9,0 };
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            ans = ans + query(1, 1, n, l[top[v]], l[v]);
            v = fa[top[v]];
        }
        else {
            ans = ans + query(1, 1, n, l[top[u]], l[u]);
            u = fa[top[u]];
        }
    } 
    if (dep[u] <= dep[v]) ans = ans + query(1, 1, n, l[u], l[v]);
    else ans = ans + query(1, 1, n, l[v], l[u]);
    return ans;
}

signed main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%lld%lld", &u, &v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);
    scanf("%lld", &m);
    for (int i = 0; i < m; i++) {
        int u, v;
        static char op[10];
        scanf("%s%lld%lld", op, &u, &v);
        if (op[0] == 'C') {
            change(1, 1, n, l[u], v);
        }
        else {
            info ans = query1(u, v);
            if (op[1] == 'M') printf("%d\n", ans.maxv);
            else printf("%d\n", ans.sum);
        }
    }
    return 0;
}



/*19
1 2
2 3
3 4
4 5
5 6
5 7
3 10
2 8
8 9
1 11
11 12
12 13
12 14
1 15
15 16
16 17
17 18
16 19*/

 

posted @ 2022-12-30 23:13  Aacaod  阅读(13)  评论(0编辑  收藏  举报