luogu P6329 【模板】点分树 | 震波

https://www.luogu.com.cn/problem/P6329

先建一棵点分树
考虑树上的每个节点维护什么
因为树高是log的,所以怎么暴力怎么维护就好了

对每个点 $u$ 维护两个前缀和(用树状数组实现):一个记录 $u$ 的点分树子树中,到 $u$ 的距离为 $i$ 的点的权值和;
另一个记录 $u$ 的点分树子树中,到 $fa[u]$(点分树上的父亲)的距离为 $i$ 的点的权值和
然后一路往上跳,每次加上fa[u]的减去u的即可
相当于是点分治中减去重复计算的部分
空间处理用vector即可
code:

#include<bits/stdc++.h>
#define N 500050
using namespace std;
// One directed edge stored in a linked adjacency list.
struct edge {
    int v;    // vertex this edge leads to
    int nxt;  // index in e[] of the next edge leaving the same vertex, -1 at the end
};
// Edge pool. Twice N because main() stores every undirected edge in both directions.
edge e[N << 1];
// p[u]: index in e[] of the most recently inserted edge leaving u, or -1 if u has none.
int p[N];
// Number of edges stored so far; doubles as the next free slot in e[].
int eid;

// Empties the adjacency structure: rewinds the edge pool and marks every list as empty.
void init() {
    eid = 0;
    // All bytes 0xFF gives every int element the value -1, the "no edge" sentinel.
    memset(p, -1, sizeof(p));
}
// Appends the directed edge u -> v to the front of u's adjacency list.
void insert(int u, int v) {
    const int slot = eid++;
    // The new edge links to the previous head, then becomes the head itself.
    e[slot] = {v, p[u]};
    p[u] = slot;
}
// dep[u]: depth of u in the tree rooted by dfsp() (root gets 1; dep[0] stays 0 as a sentinel).
int dep[N];
// fa[u][j]: the 2^j-th ancestor of u, or 0 when it does not exist.
// Column 0 is filled by dfsp(); columns 1..18 are filled in main().
int fa[N][20];
// Lowest common ancestor of x and y by binary lifting over fa[][] and dep[].
int LCA(int x, int y) {
    // Make x the deeper (or equally deep) vertex.
    if (dep[x] < dep[y]) swap(x, y);

    // Raise x, largest jumps first, while the jump does not go above y's depth.
    for (int k = 18; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= dep[y]) x = up;
    }
    if (x == y) return x;

    // Raise both in lockstep as long as their 2^k-th ancestors differ.
    for (int k = 18; k >= 0; --k) {
        if (fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    // x and y are now distinct children of the answer.
    return fa[x][0];
}
// Number of edges on the tree path between x and y.
int dis(int x, int y) {
    const int top = LCA(x, y);
    return dep[x] + dep[y] - 2 * dep[top];
}
// Depth-first walk from u that fills dep[] and the direct-parent column fa[][0].
// Expects fa[u][0] to already hold u's parent (0 for the root).
void dfsp(int u) {
    dep[u] = 1 + dep[fa[u][0]];
    for (int k = p[u]; k != -1; k = e[k].nxt) {
        const int child = e[k].v;
        if (child != fa[u][0]) {  // do not walk back up the edge we came from
            fa[child][0] = u;
            dfsp(child);
        }
    }
}
// Fenwick trees attached to each centroid u (sized in solve()):
//   t[0][u] is keyed by distance to u itself,
//   t[1][u] is keyed by distance to FA[u], u's parent in the centroid tree.
vector<int> t[2][N];
// size[u]:  subtree size scratch written by dfs(); for a chosen centroid, solve()
//           overwrites it with the component size used as the Fenwick length.
//           NOTE: under C++17 with `using namespace std;` the bare name collides
//           with std::size, so uses must be spelled ::size.
int size[N];
// msize[u]: size of the largest piece left if u is removed (computed by dfs()).
int msize[N];
// vis[u]:   set to 1 once u has been picked as a centroid.
int vis[N];
// FA[u]:    parent of u in the centroid tree, 0 for the top centroid.
int FA[N];
// Working state of the current centroid search:
int sz;  // size assumed for the component being split
int mx;  // best (smallest) msize seen so far
int rt;  // vertex achieving mx
void dfs(int u, int ff) {
    size[u] = 1; msize[u] = 0;
    for(int i = p[u]; i + 1; i = e[i].nxt) {
        int v = e[i].v;
        if(v == ff || vis[v]) continue;
        dfs(v, u); size[u] += size[v];
        msize[u] = max(msize[u], size[v]);
    }
    msize[u] = max(msize[u], sz - size[u]);
    if(msize[u] < mx) mx = msize[u], rt = u;
}
void solve(int u, int ff, int n) {
    mx = sz = n, rt = 0;
    dfs(u, u); u = rt;
    FA[u] = ff; size[u] = sz;
    t[0][u].resize(sz + 5), t[1][u].resize(sz + 5);
    vis[u] = 1;
    for(int i = p[u]; i + 1; i = e[i].nxt) {
        int v = e[i].v;
        if(vis[v]) continue;
        solve(v, u, size[v]);
    }
}
// Isolates the lowest set bit of x. Both uses of the argument and the whole
// expansion are parenthesised so that expression arguments such as
// lowbit(a + b) expand with the intended precedence.
#define lowbit(x) ((x) & -(x))
void update(int o, int u, int x, int y) { ++ x;
    for(; x <= size[u] + 1; x += lowbit(x)) t[o][u][x] += y;
}
int query(int o, int u, int x) { x = min(x + 1, size[u] + 1);
    int ret = 0;
    for(; x; x -= lowbit(x)) ret += t[o][u][x];
    return ret;
}
// Adds y to the weight of vertex x in every Fenwick tree that covers it.
void Add(int x, int y) {
    // Every centroid-tree ancestor c of x (x included): record x by its distance to c.
    for (int c = x; c != 0; c = FA[c]) {
        update(0, c, dis(c, x), y);
    }
    // Every such c that has a parent: record x by its distance to that parent,
    // which is what a query later subtracts to avoid double counting.
    for (int c = x; FA[c] != 0; c = FA[c]) {
        update(1, c, dis(FA[c], x), y);
    }
}
// n: number of vertices, m: number of operations.
int n, m;
// val[v]: current weight of vertex v (1-based).
int val[N];

// Reads the tree and the operations from stdin, answers each query on stdout.
int main() {
    init();
    scanf("%d%d", &n, &m);
    for (int v = 1; v <= n; ++v) scanf("%d", &val[v]);
    for (int k = 1; k < n; ++k) {
        int a, b;
        scanf("%d%d", &a, &b);
        insert(a, b);
        insert(b, a);
    }

    // Rooted-tree preprocessing for LCA / distance queries.
    dfsp(1);
    for (int j = 1; j <= 18; ++j) {
        for (int v = 1; v <= n; ++v) {
            fa[v][j] = fa[fa[v][j - 1]][j - 1];
        }
    }

    // Centroid tree, then seed every Fenwick tree with the initial weights.
    solve(1, 0, n);
    for (int v = 1; v <= n; ++v) Add(v, val[v]);

    int lst = 0;  // previous answer; the input is XOR-encoded with it
    while (m--) {
        int o, x, y;
        scanf("%d%d%d", &o, &x, &y);
        x ^= lst;
        y ^= lst;

        if (o != 0) {
            // Point assignment: apply the difference, then remember the new weight.
            Add(x, y - val[x]);
            val[x] = y;
            continue;
        }

        // Sum of weights within distance y of x: x's own centroid subtree first...
        int total = query(0, x, y);
        // ...then each centroid-tree ancestor, minus the part already counted
        // through the child c we climbed up from.
        for (int c = x; FA[c] != 0; c = FA[c]) {
            const int up = FA[c];
            const int d = dis(x, up);
            if (d > y) continue;
            total += query(0, up, y - d) - query(1, c, y - d);
        }
        lst = total;
        printf("%d\n", lst);
    }
    return 0;
}
posted @ 2021-08-05 16:59  lahlah  阅读(37)  评论(0编辑  收藏  举报