树链剖分(重链剖分)
Part 1 简单介绍
(1)前置知识
线段树、LCA
(2)基本定义
树链剖分是一种树路径维护算法
基本原理是 将一棵树划分成若干条连,用数据结构去维护每一条链,复杂度为O(logN)
本质是一些数据结构在树上的推广\
(3)方法
常见的路径剖分方式是轻重链剖分(启发式剖分)
(4)介绍几个概念
重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;(节点数目包括自身)
轻儿子:父亲节点中除了重儿子以外的儿子;
重边:父亲结点和重儿子连成的边;
轻边:父亲节点和轻儿子连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;
top节点:重链的起点;
可能不好理解,那么如图(先不已dfs序编号):
以这个图为例:
(1)先把每个节点的siz列出
siz[1] = 14;
siz[2] = 6;
siz[3] = 7;
siz[4] = 3;
siz[5] = 2;
siz[6] = 3;
siz[7] = 3;
siz[8 ~ 14] = 1;
(2)则每个节点的重儿子(我们用son表示)易知:
son[1] = 3;
son[2] = 4;
son[3] = 6 || 7;
son[4] = 8 || 9;
son[5] = 10;
son[6] = 11 || 12;
son[7] = 13 || 14;
son[8 ~ 14] = 0;
注:因为siz[6] == siz[7],所以3的重儿子为6或7,其余同理
把重儿子都标出来(蓝点):
把重链标出来(橙线):
把top节点标出来(红点):
相信你已经理解了(莫名自信)
Part 2 树链剖分的思路
将一棵每个节点的儿子按照siz[]的大小划分为重儿子和轻儿子,将一棵树划分为一条条链
利用dfs序,将同一个链上的点放在一起,链上的点编号都变成了连续的,那么就可以与区间问题扯上关系,也就是可以建立线段树了
在对两个点进行操作时,可以跳过很多节点(与倍增挺像的,可以类比LCA)
从而达到减小复杂度的目的
Part 3 Code实现
1. 先跑一遍dfs(初始化)
我们在这个dfs中要做的是:
初始化siz[], fath[], dep[];
三个数组顾名思义,不多解释了
上代码:
void dfs1(int now, int fa) {
fath[now] = fa; dep[now] = dep[fa] + 1; siz[now] = 1;//确定以x为根的子树的大小、父亲、深度
for(int i = head[now]; i; i = line[i].nxt) {
//链式前向星遍历边
int y = line[i].to;
if(y == fa) continue;//不能回到父节点吧?要不就死循环了,不信试试(划)
dfs1(y, now);
siz[now] += siz[y];//回溯时更新siz
if(siz[y] > siz[son[now]]) son[now] = y;//挑出重儿子
}
}
2. 再跑一边dfs(分链)
这一遍dfs我们要做的是:
- 确定dfs序,也就是id[];
- 赋值每个点的初始之到新的编号上;
- 处理每个点所在链的顶端top;
- 处理每条链;
dfs的顺序是先处理重儿子在处理轻儿子
上代码:
void dfs2(int now, int tp) {
id[x] = ++cnt; //标记每个点的新编号
value[cnt] = w[now]; //把每个点的初始值赋值到新编号上
top[now] = tp; //记录当前节点所在链的顶端
if(!son[now]) return; //没有儿子就返回
dfs2(son[now], tp); //先处理重儿子
for(int i = head[now]; i; i = line[i].nxt) { //链式前向星遍历边
int y = line[i].to;
if(y == fath[now] || y == son[now]) return; //不能回到父亲,并且在前面已经遍历过重儿子,所以也跳过
dfs2(y, y); //对于每个轻儿子都有一条从它本身开始的轻链
}
}
学长说过,不会就要手模
前面说dfs序是先处理重儿子在处理轻儿子
不懂哇!!!
那就上模拟!!!
因为顺序是先重再轻,所以每一条重链的编号是连续的
因为是dfs,所以每一个子树的新编号也是连续的
那现在!!
4. 处理问题
现在回顾一下我们要处理的问题
- 处理任意两点间路径上的点权和
- 处理一点及其子树的点权和
- 修改任意两点间路径上的点权
- 修改一点及其子树的点权
Problem 1
表示将树从 x 到 y 结点最短路径上所有节点的值都加上 k
inline void update(int now, int l, int r, int L, int R, int k) { //普通线段树区间修改
if(L <= l && R >= r) {lzy[now] += k, t[now] += k * (r - l + 1); return;}
if(lzy[now]) pushdown(now, r - l + 1);
int mid = l + r >> 1;
if(L <= mid) update(now << 1, l, mid, L, R, k);
if(R >= mid + 1) update(now << 1 | 1, r, mid + 1, L, R, k);
t[now] = t[now << 1] + t[now << 1 | 1];
}
inline void updrange(int x, int y, int k) {
while(top[x] != top[y]) { //如果要改的区间[x, y]中x节点与y节点不在同一个链上
if(dep[top[x]] < dep[top[y]]) swap(x, y); //保持x所在链的顶端的深度比y所在链的顶端的深度大
update(1, 1, n, id[top[x]], id[x], k); //修改x到x所在链的顶端这一段的点权
x = fa[top[x]]; //把x上移到x所在链顶端的那个点的上面一个点
//为什么要移到顶端的上面一个点?因为当前这个链整完了,
//所以这个操作相当于往上跳了一个链
}
if(dep[x] > dep[y]) swap(x, y); //把x放在上边
update(1, 1, n, id[x], id[y], k); //x和y已经在一条链上了,所以直接修改他俩之间所有点就行了
}
Problem 2
求树从 x 到 y 结点最短路径上所有节点的值之和
像线段树一样,查询与修改是相似的,可以参照修改代码中的注释理解一下查询。
inline void query(int rt, int l, int r, int L, int R) {
if(L <= l && r <= R) {res += a[rt]; return;}
if(lzy[rt]) pushdown(rt, len);
if(L <= mid) query(lson, L, R);
if(R > mid) query(rson, L, R);
}
inline int qrange(int x, int y) {
int ans = 0;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
res = 0;
query(1, 1, n, id[top[x]], id[x]);
ans += res;
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
res = 0;
query(1, 1, n, id[x], id[y]);
ans += res;
return ans;
}
Problem 3
将以 x 为根节点的子树内所有节点值都加上 k
inline void updson(int x, int k) {
update(1, 1, n, id[x], id[x] + siz[x] - 1, k);
//因为siz[x]是包括本身在内的子树的节点个数,所以id[x] + siz[x] - 1是x的子树中节点编号最大的一个节点
//又因为x的子树中所有节点的编号都是连续的,所以直接修改[id[x], id[x] + siz[x] - 1]这个区间即可
}
Problem 4
求以 x 为根节点的子树内所有节点值之和
与修改同理
inline int qson(int x) {
res = 0;
query(1, 1, n, id[x], id[x] + siz[x] - 1);
return res;
}
Part 4 AC代码
/*
Worker:zcxxxxx
*/
#include<bits/stdc++.h>
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define len (r - l + 1)
#define mid ((l + r) >> 1)
using namespace std;
const int maxn = 1e5 + 7;
const int INF = 0x3f3f3f3f;
inline int read() {
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
int n, m, r, mod, cnt, res;
int w[maxn], a[maxn << 2], lzy[maxn << 2];
struct edge {
int nxt, to;
}line[maxn << 1];
int head[maxn];
int wt[maxn], id[maxn], tim, fa[maxn], dep[maxn], siz[maxn], son[maxn], top[maxn];
/* ------------------------------------------------------------------------------- xian duanshu */
void pushdown(int rt, int lenn){
lzy[rt << 1] += lzy[rt];
lzy[rt << 1 | 1] += lzy[rt];
a[rt << 1] += lzy[rt] * (lenn - (lenn >> 1));
a[rt << 1 | 1] += lzy[rt] * (lenn >> 1);
a[rt<<1]%=mod;
a[rt<<1|1]%=mod;
lzy[rt] = 0;
}
inline void build(int rt, int l, int r) {
if(l == r) {
a[rt] = wt[l];
if(a[rt] > mod) a[rt] %= mod;
return;
}
build(lson);
build(rson);
a[rt] = (a[rt << 1] + a[rt << 1 | 1]) % mod;
}
inline void query(int rt, int l, int r, int L, int R) {
if(L <= l && r <= R) {res += a[rt]; res %= mod; return;}
if(lzy[rt]) pushdown(rt, len);
if(L <= mid) query(lson, L, R);
if(R > mid) query(rson, L, R);
}
inline void update(int rt, int l, int r, int L, int R, int c) {
if(L <= l && r <= R) {lzy[rt] += c; a[rt] += c * len;}
else {
if(lzy[rt]) pushdown(rt, len);
if(L <= mid) update(lson, L, R, c);
if(R > mid) update(rson, L, R, c);
a[rt] = (a[rt << 1] + a[rt << 1 | 1]) % mod;
}
}
/* ------------------------------------------------------------------------------- xian duan shu */
inline void add(int u, int v) {
line[++cnt].nxt = head[u];
line[cnt].to = v;
head[u] = cnt;
}
/* ------------------------------------------------------------------------------- */
inline int qrange(int x, int y) {
int ans = 0;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
res = 0;
query(1, 1, n, id[top[x]], id[x]);
ans += res;
ans %= mod;
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
res = 0;
query(1, 1, n, id[x], id[y]);
ans += res;
return ans % mod;
}
inline void updrange(int x, int y, int k) {
k %= mod;
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
update(1, 1, n, id[top[x]], id[x], k);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x, y);
update(1, 1, n, id[x], id[y], k);
}
inline int qson(int x) {
res = 0;
query(1, 1, n, id[x], id[x] + siz[x] - 1);
return res;
}
inline void updson(int x, int k) {
update(1, 1, n, id[x], id[x] + siz[x] - 1, k);
}
inline int dfs1(int x, int f, int deep) {
dep[x] = deep; siz[x] = 1; fa[x] = f;
int maxson = -1;
for(int i = head[x]; i; i = line[i].nxt) {
int y = line[i].to;
if(y == f) continue;
dfs1(y, x, deep + 1);
siz[x] += siz[y];
if(siz[y] > maxson) son[x] = y, maxson = siz[y];
}
}
inline void dfs2(int x, int tp) {
id[x] = ++tim; wt[tim] = w[x]; top[x] = tp;
if(!son[x]) return;
dfs2(son[x], tp);
for(int i = head[x]; i; i = line[i].nxt) {
int y = line[i].to;
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
/* ------------------------------------------------------------------------------- main */
int main() {
n = read(), m = read(), r = read(), mod = read();
for(int i = 1; i <= n; i++) w[i] = read();
for(int x, y, i = 1; i < n; i++) {
x = read(), y = read();
add(x, y); add(y, x);
}
dfs1(r, 0, 1);
dfs2(r, r);
build(1, 1, n);
for(int opt, x, y, z, i = 1; i <= m; i++) {
opt = read();
if(opt == 1) {
x = read(), y = read(), z = read();
updrange(x, y, z);
} else if(opt == 2) {
x = read(), y = read();
printf("%lld\n", qrange(x, y));
} else if(opt == 3) {
x = read(), z = read();
updson(x, z);
} else if(opt == 4) {
x = read();
printf("%lld\n", qson(x));
}
}
return 0;
}
注:
我们看到main函数中build建树是在dfs1与dfs2之后的
为什么呢?
因为dfs是为了初始化和建立dfs序,并且线段树要是根据这个dfs序进行修改查询的,所以要先初始化再建树
因为这个我调了一下午……