如何写树链剖分
树链剖分模板
树链剖分可以把树上的点划分成一条条连续的“链”,“链”是一条简单路径,上面的每个点满足祖先后代关系,从根节点到每个点都只需经过最多\(\log_2n\)条链。 链上的每一个点 dfs序也是连续的, 故可以配合其他数据结构解决很多树上查询问题(近乎所有静态树问题)
先定义siz[u]
为点u
的子树大小。
int siz[MAXN]; //每个点的子树大小
int dep[MAXN]; //每个点的深度
int mxson[MAXN];//每个点的重儿子
链分为重链和轻链。 考虑每个点,它只会和它的一个儿子组成链,这个儿子叫做 重儿子, 其他儿子叫做 轻儿子。 重儿子满足: siz[mxson[u]]
最大.
第一次dfs:
任务清单:
- 求出每个节点 u 的 siz[u], dep[u], fa[u]
- 求出每个节点的重儿子(注意叶节点没有重儿子!)
void dfs1(int u, int faa){
dep[u] = dep[faa] + 1;
siz[u] = 1;
fa[u] = faa;
int mxsiz = 0;
for(int e = h[u]; e; e = nxt[e]){
int v = ev[e];
if(v == faa) continue;
dfs1(v, u);
if(mxsiz < siz[v]) {
mxson[u] = v;
mxsiz = siz[v];
}
siz[u] += siz[v];
}
}
第二次dfs: 确定dfs序和整棵树的剖分情况.
任务清单:
- 求出每个点的“新编号”, 所有在线段树上的操作都要使用 新编号 (因为只有使用新编号,每条链上的点\(dfs\)序才是连续的。
- 求出每个点
u
的top[u]
(下面有解释)
int top[u]; //u所在的链的顶端
int dfn[MAXN]; //每个点的dfs序(第二次dfs后) 作为每个点的新编号
int timestamp = 0;
void dfs2(int u, int topf){
dfn[u] = ++timestamp;
top[u] = topf;
if(mxson[u] == 0) return; //VERY IMPORTANT!!!!
//上面这句代码不加会RE! (没有儿子的节点不可以dfs!)
//或者加一句if(u == 0) return; 也可以
dfs2(mxson[u], topf);
for(int e = h[u]; e; e = nxt[e]){
int v = ev[e];
if(v != fa[u] && v != mxson[u]) dfs2(v, v);
}
}
两遍\(dfs\)后, 对整棵树的剖分已经完成.
易错点
注意!! 注意!! 注意!!
siz[u], top[u], fa[u], dep[u], mxson[u]
中的 \(u\) 是每个点的 旧编号!!!
树链剖分的应用
举个例子: 通过树链剖分实现链上求和等操作.
链上询问
一边求 \(LCA\) 一边使用线段树修改/求和. (因为每条链上的点的\(dfs\)序连续)
所有链上操作都可以基于该模板。 瞪大眼睛,注释非常重要!
#define swapdep(u, v) if(dep[top[u]] < dep[top[v]]) swap(u, v) //u所在链更深
int qchain(int u, int v) {
ll ret = 0;
//和倍增求LCA一样, 这里求LCA也是让两个节点轮流向上"跳"。
//此时,要求跳的那个点 所在链的深度 较大(因为跳完之后,u = fa[top[u]])
while(top[u] != top[v]) {
swapdep(u, v);
upd(ret, query(1, 1, n, dfn[top[u]], dfn[u]) );
u = fa[top[u]];
}
//两个点在同一条链上以后, 就要根据 点自己的深度 来判断谁先谁后了。(查询[l, r]区间和时, l <= r )
if(dep[u] < dep[v]) swap(u, v);
upd(ret, query(1, 1, n, dfn[v], dfn[u]) );
return ret;
}
链上修改同理.
子树询问
每个点的子树\(dfs\)序连续, 同样使用线段树修改/求和
注意 siz[u] 中的u是每个点的 初始编号
int qsub(int u){
return query(1, 1, n, dfn[u], dfn[u] + siz[u] - 1);
}
void updsub(int u, int val){
update(1, 1, n, dfn[u], dfn[u] + siz[u] - 1, val);
}
CODE (luogu P3384)
#include<bits/stdc++.h>
using namespace std;
#define ll long long
void write(ll x) {
if(x < 0) {
putchar('-');
x = -x;
}
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
ll read(){
ll ret = 0, f = 1; char c = getchar();
for(; !isdigit(c); c = getchar()) if('-'==c) f = -1;
for(; isdigit(c); c = getchar()) ret = ret*10 + c - '0';
return ret * f;
}
#define nl puts("")
#define bs putchar(' ')
const int _maxn = 150005;
int dep[_maxn], siz[_maxn], fa[_maxn], mxson[_maxn];
int dfn[_maxn], top[_maxn];
int n, m, r, p;
int a[_maxn];
int h[_maxn], nxt[_maxn*2], ev[_maxn*2], cnte;
void adde(int u, int v){
cnte++;
ev[cnte] = v;
nxt[cnte] = h[u];
h[u] = cnte;
}
void dfs1(int u, int faa){
dep[u] = dep[faa] + 1;
siz[u] = 1;
fa[u] = faa;
int mxsiz = 0;
for(int e = h[u]; e; e = nxt[e]){
int v = ev[e];
if(v == faa) continue;
dfs1(v, u);
if(mxsiz < siz[v]) {
mxson[u] = v;
mxsiz = siz[v];
}
siz[u] += siz[v];
}
}
int timestamp = 0;
void dfs2(int u, int topf){
dfn[u] = ++timestamp;
top[u] = topf;
if(mxson[u] == 0) return; //VERY IMPORTANT!!!!
dfs2(mxson[u], topf);
for(int e = h[u]; e; e = nxt[e]){
int v = ev[e];
if(v != fa[u] && v != mxson[u]) dfs2(v, v);
}
}
//segtree
#define il inline
ll sum[_maxn*4], tag[_maxn*4];
il int ls(int o) {return o*2;}
il int rs(int o) {return o*2+1;}
void upd(ll &x, ll y) {
x = (p+x+y) % p;
}
void _op(int o, int l, int r, int x){
upd(tag[o], x);
upd(sum[o], (ll)(r-l+1) * x % p);
}
il void pushdown(int o, int l, int r) {
if(tag[o] == 0) return;
int mid = (l+r)>>1;
_op(ls(o), l, mid, tag[o]);
_op(rs(o), mid+1, r, tag[o]);
tag[o] = 0;
}
void update(int o, int l, int r, int ql, int qr, int v) {
pushdown(o, l, r);
if(ql <= l && r <= qr) {
upd(tag[o], v);
upd(sum[o], (ll)(r-l+1)*v%p);
return;
}
int mid = (l+r)>>1;
if(ql <= mid) update(ls(o), l, mid, ql, qr, v);
if(mid < qr) update(rs(o), mid+1, r, ql, qr, v);
sum[o] = (sum[ls(o)] + sum[rs(o)]) % p;
}
ll query(int o, int l, int r, int ql, int qr) {
pushdown(o, l, r);
if(ql <= l && r <= qr) {
return sum[o];
}
int mid = (l+r)>>1;
ll ret = 0;
if(ql <= mid) upd(ret, query(ls(o), l, mid, ql, qr) );
if(mid < qr) upd(ret, query(rs(o), mid+1, r, ql, qr) );
return ret;
}
//segtree ends here.
//下面一句话非常重要!!
#define swapdep(u, v) if(dep[top[u]] < dep[top[v]]) swap(u, v) //u所在链更深
int qchain(int u, int v) {
ll ret = 0;
while(top[u] != top[v]) {
swapdep(u, v);
upd(ret, query(1, 1, n, dfn[top[u]], dfn[u]) );
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u, v);
upd(ret, query(1, 1, n, dfn[v], dfn[u]) );
return ret;
}
void updchain(int u, int v, int val) {
while(top[u] != top[v]) {
swapdep(u, v);
update(1, 1, n, dfn[top[u]], dfn[u], val);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u, v);
update(1, 1, n, dfn[v], dfn[u], val);
}
int qsub(int u){
return query(1, 1, n, dfn[u], dfn[u] + siz[u] - 1);
}
void updsub(int u, int val){
update(1, 1, n, dfn[u], dfn[u] + siz[u] - 1, val);
}
signed main(){
#ifndef ONLINE_JUDGE
freopen(".in","r",stdin);
freopen(".out", "w", stdout);
#endif
n = read(), m = read(), r = read(), p = read();
for(int i = 1; i <= n; i++) {
a[i] = read();
}
for(int i = 1; i < n; i++) {
int u = read(), v= read();
adde(u, v); adde(v, u);
}
dfs1(r, 0);
dfs2(r, r);
// for(int i = 1; i <= n; i++) {
// printf("%d mxs=%d fa=%d siz=%d dfn=%d top=%d\n", i, mxson[i], fa[i],siz[i], dfn[i], top[i]);
// }
for(int i = 1; i <= n; i++) {
update(1, 1, n, dfn[i], dfn[i], a[i]);
assert(top[i] != 0);
}
while(m--){
int cas, x, y, z;
cas = read(), x = read();
switch(cas){
case 1:
y = read(), z = read();
updchain(x, y, z);
break;
case 2:
y = read();
write(qchain(x, y)); nl;
break;
case 3:
z = read();
updsub(x, z);
break;
case 4:
write(qsub(x)); nl;
}
}
fclose(stdin), fclose(stdout);
return 0;
}