树链剖分
树链剖分,简称树剖,就是把一颗又大又高的树拆成一些链,方便使用某些数据结构。
一般树剖
我们随便 DFS
一下,将整棵树分成一些链,其中里面的 DFS
序连续。
链的数量不管怎样是固定的 \(O(N)\)。
hack:
某种 DFS
序是 \((1,3,2,5,4,7,6,9,8,11,10)\),只要你不走运刚好,就仍然可以把单次询问卡到 \(O(N)\)。
所以我们要把这坨东西优化一下。
重链优先
DFS
时优先用子树大小大的,剩下的随便。
链的数量没改进,但是单次询问的复杂度降低到 \(O(\log N)\)(因为只会有 \(O(\log N)\) 个轻边)。
证明:
考虑从某个节点往根节点走,路上遇到的边全是轻边。
很显然往上面每多一个轻边节点数量就会多一倍。
然后你就可以用这种方式把树打成一条链然后在链上写东西了(说实话你可以把树打成链后在链上跑莫队,但是这既没有必要也不是很好)。
同时我们注意到这个玩意还满足普通 DFS
的性质,所以仍然支持子树修改/查询的操作。
这个玩意也可以求 \(\operatorname{LCA}\):
-
如果两个节点在同一个链里面,深度小的就是 \(\operatorname{LCA}\)。
-
否则,第一个节点的链的顶端的深度大,将第一个节点跳至链的顶端的父亲,否则第二个节点跳至链的顶端的父亲,然后返回第一步。
既然能求 \(\operatorname{LCA}\),那么也能处理链上修改/查询。
于是就有了下面这道题。
P3384 【模板】重链剖分/树链剖分
一棵 \(N\) 个节点的树,每一个节点有一个点权 \(V_i\)。
四个操作:
链上点的点权加一个数,
查询链上的点权和,
子树点权加一个数,
查询子树点权和。
直接把树打成一个链,然后链上随便写一个线段树维护。
很难写。
代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, m, r, p, a, b, op, x, y, z;
int d[400040], t[400040], c[100010];
int fa[200020], dfn[200020], ss[200020], dep[200020], top[200020], hv[200020], dcnt, ed[200020];
vector<int> to[200020];
void pushdn(int x, int l, int r) {
int mid = (l + r) >> 1;
d[x * 2] += (mid - l + 1) * t[x];
d[x * 2 + 1] += (r - mid) * t[x];
t[x * 2] += t[x], t[x * 2 + 1] += t[x];
t[x] = 0;
}
int query(int x, int l, int r, int ql, int qr) {
if(r < ql || qr < l) {
return 0;
}
if(ql <= l && r <= qr) {
return d[x];
}
pushdn(x, l, r);
int mid = (l + r) >> 1;
return query(x * 2, l, mid, ql, qr) + query(x * 2 + 1, mid + 1, r, ql, qr);
}
void add(int x, int l, int r, int ql, int qr, int v) {
if(r < ql || qr < l) {
return ;
}
if(ql <= l && r <= qr) {
d[x] += (r - l + 1) * v;
t[x] += v;
return ;
}
pushdn(x, l, r);
int mid = (l + r) >> 1;
add(x * 2, l, mid, ql, qr, v);
add(x * 2 + 1, mid + 1, r, ql, qr, v);
d[x] = d[x * 2] + d[x * 2 + 1];
}
void DFS(int x) {
ss[x] = 1;
dep[x] = dep[fa[x]] + 1;
for(auto i : to[x]) {
if(i != fa[x]) {
fa[i] = x;
DFS(i);
ss[x] += ss[i];
if(ss[i] > ss[hv[x]]) {
hv[x] = i;
}
}
}
}
void HCD(int x) {
dfn[x] = ++dcnt;
ed[x] = dfn[x];
add(1, 1, n, dfn[x], dfn[x], c[x]);
if(hv[x]) {
top[hv[x]] = top[x];
HCD(hv[x]);
ed[x] = ed[hv[x]];
}
for(auto i : to[x]) {
if(i != fa[x] && i != hv[x]) {
top[i] = i;
HCD(i);
ed[x] = ed[i];
}
}
}
int LCA(int x, int y) {
int ans = 0;
for(; top[x] != top[y];) {
if(dep[top[x]] > dep[top[y]]) {
ans += query(1, 1, n, dfn[top[x]], dfn[x]);
x = fa[top[x]];
}else {
ans += query(1, 1, n, dfn[top[y]], dfn[y]);
y = fa[top[y]];
}
}
if(dep[x] > dep[y]) {
swap(x, y);
}
ans += query(1, 1, n, dfn[x], dfn[y]);
cout << ans % p << '\n';
return x;
}
int LCAc(int x, int y, int v) {
for(; top[x] != top[y];) {
if(dep[top[x]] > dep[top[y]]) {
add(1, 1, n, dfn[top[x]], dfn[x], v);
x = fa[top[x]];
}else {
add(1, 1, n, dfn[top[y]], dfn[y], v);
y = fa[top[y]];
}
}
if(dep[x] > dep[y]) {
swap(x, y);
}
add(1, 1, n, dfn[x], dfn[y], v);
return x;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> m >> r >> p;
for(int i = 1; i <= n; i++) {
cin >> c[i];
}
for(int i = 1; i < n; i++) {
cin >> a >> b;
to[a].push_back(b);
to[b].push_back(a);
}
DFS(r);
top[1] = 1;
HCD(r);
for(; m--; ) {
cin >> op;
if(op == 1) {
cin >> x >> y >> z;
LCAc(x, y, z);
}else {
if(op == 2) {
cin >> x >> y;
LCA(x, y);
}else {
if(op == 3) {
cin >> x >> y;
add(1, 1, n, dfn[x], ed[x], y);
}else {
cin >> x;
cout << query(1, 1, n, dfn[x], ed[x]) % p << '\n';
}
}
}
}
return 0;
}