luogu P6329 【模板】点分树 | 震波
https://www.luogu.com.cn/problem/P6329
先建一棵点分树。
考虑点分树上的每个节点需要维护什么。
因为点分树的树高是 $O(\log n)$ 的，所以每个节点怎么暴力怎么维护就好了。
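每个节点 $u$ 维护两个前缀和（代码中用树状数组实现）：
一个是 $u$ 在点分树上的子树中，到 $u$ 点的距离为 $i$ 的点的个数（带权，即这些点的点权之和）；
另一个是同一棵子树中，到 $fa[u]$（$u$ 在点分树上的父亲）的距离为 $i$ 的点的个数（同样带权）。
查询时从 $u$ 开始在点分树上一路往上跳，每次加上 $fa[u]$ 的第一个前缀和、减去 $u$ 的第二个前缀和即可。
这相当于是点分治中减去重复计算的部分（$u$ 子树内的点已经在下面一层统计过了）。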
空间处理用vector即可
code:
#include<bits/stdc++.h>
#define N 500050
using namespace std;
// Tree adjacency stored as a "forward star" linked list:
//   e[]  - pool of directed half-edges (two per undirected tree edge),
//   p[u] - index in e[] of the most recently added edge leaving u, -1 if none,
//   eid  - next free slot in e[].
struct edge {
    int v;    // head (target) vertex of this half-edge
    int nxt;  // index of the next edge leaving the same tail vertex, or -1
};
edge e[N << 1];
int p[N];
int eid;
// Clears the adjacency list: every head pointer becomes -1 ("no edge") and the
// edge pool is emptied. Must be called before the first insert().
void init() {
    std::fill(std::begin(p), std::end(p), -1);
    eid = 0;
}
void insert(int u, int v) {
e[eid].v = v;
e[eid].nxt = p[u];
p[u] = eid ++;
}
int dep[N], fa[N][20];
int LCA(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 18; i >= 0; i --) if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
if(x == y) return x;
for(int i = 18; i >= 0; i --) if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
// Number of edges on the path between x and y in the original tree.
int dis(int x, int y) {
    const int lca_depth = dep[LCA(x, y)];
    return (dep[x] - lca_depth) + (dep[y] - lca_depth);
}
// Recursive DFS over the original tree that fills dep[] and the direct parent
// fa[][0] for every vertex below u. The caller sets fa[u][0] beforehand
// (main leaves fa[1][0] == 0 for the root).
void dfsp(int u) {
    const int parent = fa[u][0];
    dep[u] = dep[parent] + 1;
    for (int id = p[u]; id != -1; id = e[id].nxt) {
        const int child = e[id].v;
        if (child != parent) {
            fa[child][0] = u;
            dfsp(child);
        }
    }
}
// Per-centroid Fenwick trees, indexed by (distance + 1):
//   t[0][u] - values of vertices in u's centroid component, keyed by distance to u;
//   t[1][u] - the same vertices, keyed by distance to FA[u].
vector<int> t[2][N];
// size[u]: subtree size while dfs() runs; for a chosen centroid, the component
// size passed to solve() (used as the Fenwick bound).
int size[N];
int msize[N];  // largest piece that would remain if this vertex were removed (dfs scratch)
int vis[N];    // 1 once the vertex has been taken as a centroid
int FA[N];     // parent in the centroid tree; 0 for the top-level centroid
int sz;        // component size assumed by the current dfs()
int mx;        // smallest msize seen so far in the current dfs()
int rt;        // vertex achieving mx (the chosen centroid)
void dfs(int u, int ff) {
size[u] = 1; msize[u] = 0;
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(v == ff || vis[v]) continue;
dfs(v, u); size[u] += size[v];
msize[u] = max(msize[u], size[v]);
}
msize[u] = max(msize[u], sz - size[u]);
if(msize[u] < mx) mx = msize[u], rt = u;
}
void solve(int u, int ff, int n) {
mx = sz = n, rt = 0;
dfs(u, u); u = rt;
FA[u] = ff; size[u] = sz;
t[0][u].resize(sz + 5), t[1][u].resize(sz + 5);
vis[u] = 1;
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(vis[v]) continue;
solve(v, u, size[v]);
}
}
// Lowest set bit of x, the Fenwick-tree step size. This is a function rather
// than the original unparenthesised macro, which mis-expanded compound
// arguments (lowbit(1 + 3) evaluated to 0 instead of 4).
constexpr int lowbit(int x) { return x & -x; }
void update(int o, int u, int x, int y) { ++ x;
for(; x <= size[u] + 1; x += lowbit(x)) t[o][u][x] += y;
}
int query(int o, int u, int x) { x = min(x + 1, size[u] + 1);
int ret = 0;
for(; x; x -= lowbit(x)) ret += t[o][u][x];
return ret;
}
// Adds y to the value of vertex x. For every centroid-tree ancestor c of x,
// including x itself, y is recorded at distance dis(c, x) in t[0][c]. For every
// such c that has a centroid parent, y is also recorded at distance
// dis(FA[c], x) in t[1][c].
void Add(int x, int y) {
    for (int c = x; c != 0; c = FA[c]) {
        update(0, c, dis(c, x), y);
        if (FA[c]) update(1, c, dis(FA[c], x), y);
    }
}
int n, m, val[N];

// Reads the tree and initial values, builds the LCA table and the centroid
// tree, and loads every value. It then processes m operations whose x and y
// are XOR-decoded with the last printed answer:
//   o == 0 : print the total value of vertices within distance y of x;
//   o != 0 : set the value of vertex x to y.
int main() {
    init();
    scanf("%d%d", &n, &m);
    for (int v = 1; v <= n; ++v) scanf("%d", &val[v]);
    for (int k = 1; k < n; ++k) {
        int a, b;
        scanf("%d%d", &a, &b);
        insert(a, b);
        insert(b, a);
    }
    dfsp(1);
    for (int j = 1; j <= 18; ++j) {
        for (int v = 1; v <= n; ++v) fa[v][j] = fa[fa[v][j - 1]][j - 1];
    }
    solve(1, 0, n);
    for (int v = 1; v <= n; ++v) Add(v, val[v]);

    int lst = 0;  // last printed answer, used to decode the next operation
    while (m--) {
        int o, x, y;
        scanf("%d%d%d", &o, &x, &y);
        x ^= lst;
        y ^= lst;
        if (o) {
            // Point assignment: add the difference to every centroid ancestor.
            Add(x, y - val[x]);
            val[x] = y;
            continue;
        }
        int ans = query(0, x, y);
        // Climb the centroid tree. At each parent, add its vertices within range
        // and remove those already counted through the child component.
        for (int c = x; FA[c]; c = FA[c]) {
            const int d = dis(x, FA[c]);
            if (d > y) continue;
            ans += query(0, FA[c], y - d) - query(1, c, y - d);
        }
        lst = ans;
        printf("%d\n", lst);
    }
    return 0;
}