【数据结构】树链剖分详细讲解
“在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。
树链剖分是把一棵树分割成若干条链,以便于维护信息的一种方法,其中最常用的是重链剖分(Heavy Path Decomposition,重路径分解),所以一般提到树链剖分或树剖都是指重链剖分。除此之外还有长链剖分和实链剖分等,本文暂不介绍。
首先我们需要明确概念:
- 重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;
- 轻儿子:父亲节点中除了重儿子以外的儿子;
- 重边:父亲结点和重儿子连成的边;
- 轻边:父亲节点和轻儿子连成的边;
- 重链:由多条重边连接而成的路径;
- 轻链:由多条轻边连接而成的路径;
我们定义树上一个节点的子节点中子树最大的一个为它的重子节点,其余的为轻子节点。一个节点连向其重子节点的边称为重边,连向轻子节点的边则为轻边。如果把根节点看作轻的,那么从每个轻节点出发,不断向下走重边,都对应了一条链,于是我们把树剖分成了 \(l\) 条链,其中 \(l\) 是轻节点的数量。
最近因为画图工具出了点问题,所以转载了Pecco学长的示意图(下面求LCA的方法的部分内容也来自Pecco学长)
剖分后的树(重链)有如下性质:
-
对于节点数为 \(n\) 的树,从任意节点向上走到根节点,经过的轻边数量不会超过 \(log\ n\)
这是因为当我们向下经过一条 轻边 时,所在子树的大小至少会除以二。所以说,对于树上的任意一条路径,把它拆分成从 \(lca\) 分别向两边往下走,分别最多走 \(O(\log n)\) 次,树上的每条路径都可以被拆分成不超过 \(O(\log n)\) 条重链。
-
树上每个节点都属于且仅属于一条重链 。
重链开头的结点不一定是重子节点(因为重边是对于每一个结点都有定义的)。所有的重链将整棵树 完全剖分 。
尽管树链部分看起来很难实现(的确有点繁琐),但我们可以用两个 DFS 来实现树链(树剖)。
相关伪码(来自 OI wiki)
第一个 DFS 记录每个结点的父节点(father)、深度(deep)、子树大小(size)、重子节点(hson)。
第二个 DFS 记录所在链的链顶(top,应初始化为结点本身)、重边优先遍历时的 DFS 序(dfn)、DFS 序对应的节点编号(rank)。
以下为代码实现。
我们先给出一些定义:
- \(fa(x)\) 表示节点 \(x\) 在树上的父亲(也就是父节点)。
- \(dep(x)\) 表示节点 \(x\) 在树上的深度。
- \(siz(x)\) 表示节点 \(x\) 的子树的节点个数。
- \(son(x)\) 表示节点 \(x\) 的 重儿子 。
- \(top(x)\) 表示节点 \(x\) 所在 重链 的顶部节点(深度最小)。
- \(dfn(x)\) 表示节点 \(x\) 的 DFS 序 ,也是其在线段树中的编号。
- \(rnk(x)\) 表示 DFS 序所对应的节点编号,有 \(rnk(dfn(x))=x\) 。
我们进行两遍 DFS 预处理出这些值,其中第一次 DFS 求出 \(fa(x)\) , \(dep(x)\) , \(siz(x)\) , \(son(x)\) ,第二次 DFS 求出 \(top(x)\) , \(dfn(x)\) , \(rnk(x)\) 。
// 当然树链写法不止一种,这个是我学习Oi wiki上知识点记录的模板代码
void dfs1(int o) {
son[o] = -1, siz[o] = 1;
for (int j = h[o]; j; j = nxt[j])
if (!dep[p[j]]) {
dep[p[j]] = dep[o] + 1;
fa[p[j]] = o;
dfs1(p[j]);
siz[o] += siz[p[j]];
if (son[o] == -1 || siz[p[j]] > siz[son[o]])
son[o] = p[j];
}
}
void dfs2(int o, int t) {
top[o] = t;
dfn[o] = ++cnt;
rnk[cnt] = o;
if (son[o] == -1)
return;
dfs2(son[o], t); // 优先对重儿子进行 DFS,可以保证同一条重链上的点 DFS 序连续
for (int j = h[o]; j; j = nxt[j])
if (p[j] != son[o] && p[j] != fa[o])
dfs2(p[j], p[j]);
}
// 写法2:来自Peocco学长,代码仅作学习使用
void dfs1(int p, int d = 1){
int Siz = 1,ma = 0;
dep[p] = d;
for(auto q : edges[p]){ // for循环写法和auto是C++11标准,竞赛可用
dfs1(q,d + 1);
fa[q] = p;
Siz += sz[q];
if(sz[q] > ma)
hson[p] = q, ma = sz[q];// hson = 重儿子
}
sz[p] = Siz;
}
// 需要先把根节点的top初始化为自身
void dfs2(int p){
for(auto q : edges[p]){
if(!top[q]){
if(q == hson[p])
top[q] = top[p];
else
top[q] = q;
dfs2(q);
}
}
}
以上这样便完成了剖分。
学习到这里想想开头的那句话:
“在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。
如果不能一下想不到线段树解决不了的问题的话不如看看这道题 ↓
Hdu 3966 Aragorn's Story
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3966
题意:给一棵树,并给定各个点权的值,然后有3种操作:
I C1 C2 K: 把C1与C2的路径上的所有点权值加上K
D C1 C2 K:把C1与C2的路径上的所有点权值减去K
Q C:查询节点编号为C的权值
分析:典型的树链剖分题目,先进行剖分,然后用线段树去维护即可
// Author : RioTian
// Time : 20/11/30
#include <bits/stdc++.h>
using namespace std;
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
typedef long long ll;
typedef int lld;
stack<int> ss;
const int maxn = 2e5 + 10;
const int inf = ~0u >> 2; // 1073741823
int M[maxn << 2];
int add[maxn << 2];
struct node {
int s, t, w, next;
} edges[maxn << 1];
int E, n;
int Size[maxn], fa[maxn], heavy[maxn], head[maxn], vis[maxn];
int dep[maxn], rev[maxn], num[maxn], cost[maxn], w[maxn];
int Seg_size;
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
void add_edge(int s, int t, int w) {
edges[E].w = w;
edges[E].s = s;
edges[E].t = t;
edges[E].next = head[s];
head[s] = E++;
}
void dfs(int u, int f) { //起点,父节点
int mx = -1, e = -1;
Size[u] = 1;
for (int i = head[u]; i != -1; i = edges[i].next) {
int v = edges[i].t;
if (v == f)
continue;
edges[i].w = edges[i ^ 1].w = w[v];
dep[v] = dep[u] + 1;
rev[v] = i ^ 1;
dfs(v, u);
Size[u] += Size[v];
if (Size[v] > mx)
mx = Size[v], e = i;
}
heavy[u] = e;
if (e != -1)
fa[edges[e].t] = u;
}
inline void pushup(int rt) {
M[rt] = M[rt << 1] + M[rt << 1 | 1];
}
void pushdown(int rt, int m) {
if (add[rt]) {
add[rt << 1] += add[rt];
add[rt << 1 | 1] += add[rt];
M[rt << 1] += add[rt] * (m - (m >> 1));
M[rt << 1 | 1] += add[rt] * (m >> 1);
add[rt] = 0;
}
}
void built(int l, int r, int rt) {
M[rt] = add[rt] = 0;
if (l == r)
return;
int m = (r + l) >> 1;
built(lson), built(rson);
}
void update(int L, int R, int val, int l, int r, int rt) {
if (L <= l && r <= R) {
M[rt] += val;
add[rt] += val;
return;
}
pushdown(rt, r - l + 1);
int m = (l + r) >> 1;
if (L <= m)
update(L, R, val, lson);
if (R > m)
update(L, R, val, rson);
pushup(rt);
}
lld query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R)
return M[rt];
pushdown(rt, r - l + 1);
int m = (l + r) >> 1;
lld ret = 0;
if (L <= m)
ret += query(L, R, lson);
if (R > m)
ret += query(L, R, rson);
return ret;
}
void prepare() {
int i;
built(1, n, 1);
memset(num, -1, sizeof(num));
dep[0] = 0;
Seg_size = 0;
for (i = 0; i < n; i++)
fa[i] = i;
dfs(0, 0);
for (i = 0; i < n; i++) {
if (heavy[i] == -1) {
int pos = i;
while (pos && edges[heavy[edges[rev[pos]].t]].t == pos) {
int t = rev[pos];
num[t] = num[t ^ 1] = ++Seg_size;
// printf("pos=%d val=%d t=%d\n", Seg_size, edge[t].w, t);
update(Seg_size, Seg_size, edges[t].w, 1, n, 1);
pos = edges[t].t;
}
}
}
}
int lca(int u, int v) {
while (1) {
int a = find(u), b = find(v);
if (a == b)
return dep[u] < dep[v] ? u : v; // a,b在同一条重链上
else if (dep[a] >= dep[b])
u = edges[rev[a]].t;
else
v = edges[rev[b]].t;
}
}
void CH(int u, int lca, int val) {
while (u != lca) {
int r = rev[u]; // printf("r=%d\n",r);
if (num[r] == -1)
edges[r].w += val, u = edges[r].t;
else {
int p = fa[u];
if (dep[p] < dep[lca])
p = lca;
int l = num[r];
r = num[heavy[p]];
update(l, r, val, 1, n, 1);
u = p;
}
}
}
void change(int u, int v, int val) {
int p = lca(u, v);
// printf("p=%d\n",p);
CH(u, p, val);
CH(v, p, val);
if (p) {
int r = rev[p];
if (num[r] == -1) {
edges[r ^ 1].w += val; //在此处发现了我代码的重大bug
edges[r].w += val;
} else
update(num[r], num[r], val, 1, n, 1);
} //根节点,特判
else
w[p] += val;
}
lld solve(int u) {
if (!u)
return w[u]; //根节点,特判
else {
int r = rev[u];
if (num[r] == -1)
return edges[r].w;
else
return query(num[r], num[r], 1, n, 1);
}
}
int main() {
// freopen("in.txt", "r", stdin);
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int t, i, a, b, c, m, ca = 1, p;
while (cin >> n >> m >> p) {
memset(head, -1, sizeof(head));
E = 0;
for (int i = 0; i < n; ++i)
cin >> w[i];
for (int i = 0; i < m; ++i) {
cin >> a >> b;
a--, b--;
add_edge(a, b, 0), add_edge(b, a, 0);
}
prepare(); // 预处理
string op;
while (p--) {
cin >> op;
if (op[0] == 'I') { //区间添加
cin >> a >> b >> c;
a--, b--;
change(a, b, c);
} else if (op[0] == 'D') { //区间减少
cin >> a >> b >> c;
a--, b--;
change(a, b, -c);
} else { //查询
cin >> a;
a--;
cout << solve(a) << endl;
}
}
}
return 0;
}
由于数据很大,建议使用快读,而不是像我一样用 cin
(差了近500ms了)
折叠代码是千千dalao的解法:
Code
//千千dalao解法
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 50010;
struct Edge {
int to;
int next;
} edge[maxn << 1];
int head[maxn], tot; //链式前向星存储
int top[maxn]; // v所在重链的顶端节点
int fa[maxn]; //父亲节点
int deep[maxn]; //节点深度
int num[maxn]; //以v为根的子树节点数
int p[maxn]; // v与其父亲节点的连边在线段树中的位置
int fp[maxn]; //与p[]数组相反
int son[maxn]; //重儿子
int pos;
int w[maxn];
int ad[maxn << 2]; //树状数组
int n; //节点数目
void init() {
memset(head, -1, sizeof(head));
memset(son, -1, sizeof(son));
tot = 0;
pos = 1; //因为使用树状数组,所以我们pos初始值从1开始
}
void addedge(int u, int v) {
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}
//第一遍dfs,求出 fa,deep,num,son (u为当前节点,pre为其父节点,d为深度)
void dfs1(int u, int pre, int d) {
deep[u] = d;
fa[u] = pre;
num[u] = 1;
//遍历u的邻接点
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if (v != pre) {
dfs1(v, u, d + 1);
num[u] += num[v];
if (son[u] == -1 || num[v] > num[son[u]]) //寻找重儿子
son[u] = v;
}
}
}
//第二遍dfs,求出 top,p
void dfs2(int u, int sp) {
top[u] = sp;
p[u] = pos++;
fp[p[u]] = u;
if (son[u] != -1) //如果当前点存在重儿子,继续延伸形成重链
dfs2(son[u], sp);
else
return;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if (v != son[u] && v != fa[u]) //遍历所有轻儿子新建重链
dfs2(v, v);
}
}
int lowbit(int x) {
return x & -x;
}
//查询
int query(int i) {
int s = 0;
while (i > 0) {
s += ad[i];
i -= lowbit(i);
}
return s;
}
//增加
void add(int i, int val) {
while (i <= n) {
ad[i] += val;
i += lowbit(i);
}
}
void update(int u, int v, int val) {
int f1 = top[u], f2 = top[v];
while (f1 != f2) {
if (deep[f1] < deep[f2]) {
swap(f1, f2);
swap(u, v);
}
//因为区间减法成立,所以我们把对某个区间[f1,u]
//的更新拆分为 [0,f1] 和 [0,u] 的操作
add(p[f1], val);
add(p[u] + 1, -val);
u = fa[f1];
f1 = top[u];
}
if (deep[u] > deep[v])
swap(u, v);
add(p[u], val);
add(p[v] + 1, -val);
}
int main() {
ios::sync_with_stdio(false);
int m, ps;
while (cin >> n >> m >> ps) {
int a, b, c;
for (int i = 1; i <= n; i++)
cin >> w[i];
init();
for (int i = 0; i < m; i++) {
cin >> a >> b;
addedge(a, b);
addedge(b, a);
}
dfs1(1, 0, 0);
dfs2(1, 1);
memset(ad, 0, sizeof(ad));
for (int i = 1; i <= n; i++) {
add(p[i], w[i]);
add(p[i] + 1, -w[i]);
}
for (int i = 0; i < ps; i++) {
char op;
cin >> op;
if (op == 'Q') {
cin >> a;
cout << query(p[a]) << endl;
} else {
cin >> a >> b >> c;
if (op == 'D')
c = -c;
update(a, b, c);
}
}
}
return 0;
}
利用树链求LCA
这个部分参考了Peocco学长,十分感谢
在这道经典题中,求了LCA,但为什么树剖就可以求LCA呢?
树剖可以单次 \(O(log\ n)\)! 地求LCA,且常数较小。假如我们要求两个节点的LCA,如果它们在同一条链上,那直接输出深度较小的那个节点就可以了。
否则,LCA要么在链头深度较小的那条链上,要么就是两个链头的父节点的LCA,但绝不可能在链头深度较大的那条链上[1]。所以我们可以直接把链头深度较大的节点用其链头的父节点代替,然后继续求它与另一者的LCA。
由于在链上我们可以 \(O(1)\) 地跳转,每条链间由轻边连接,而经过轻边的次数又不超过 ,所以我们实现了 \(O(log\ n)\) 的LCA查询。
int lca(int a, int b) {
while (top[a] != top[b]) {
if (dep[top[a]] > dep[top[b]])
a = fa[top[a]];
else
b = fa[top[b]];
}
return (dep[a] > dep[b] ? b : a);
}
结合数据结构
在进行了树链剖分后,我们便可以配合线段树等数据结构维护树上的信息,这需要我们改一下第二次 DFS 的代码,我们用dfsn
数组记录每个点的dfs序,用madfsn
数组记录每棵子树的最大dfs序:(这里有点像连通图的知识了)
// 需要先把根节点的top初始化为自身
int cnt;
void dfs2(int p) {
madfsn[p] = dfsn[p] = ++cnt;
if (hson[p] != 0) {
top[hson[p]] = top[p];
dfs2(hson[p]);
madfsn[p] = max(madfsn[p], madfsn[hson[p]]);
}
for (auto q : edges[p])
if (!top[q]) {
top[q] = q;
dfs2(q);
madfsn[p] = max(madfsn[p], madfsn[q]);
}
}
注意到,每棵子树的dfs序都是连续的,且根节点dfs序最小;而且,如果我们优先遍历重子节点,那么同一条链上的节点的dfs序也是连续的,且链头节点dfs序最小。
所以就可以用线段树等数据结构维护区间信息(以点权的和为例),例如路径修改(类似于求LCA的过程):
void update_path(int x, int y, int z) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) {
update(dfsn[top[x]], dfsn[x], z);
x = fa[top[x]];
} else {
update(dfsn[top[y]], dfsn[y], z);
y = fa[top[y]];
}
}
if (dep[x] > dep[y])
update(dfsn[y], dfsn[x], z);
else
update(dfsn[x], dfsn[y], z);
}
路径查询:
int query_path(int x, int y) {
int ans = 0;
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) {
ans += query(dfsn[top[x]], dfsn[x]);
x = fa[top[x]];
} else {
ans += query(dfsn[top[y]], dfsn[y]);
y = fa[top[y]];
}
}
if (dep[x] > dep[y])
ans += query(dfsn[y], dfsn[x]);
else
ans += query(dfsn[x], dfsn[y]);
return ans;
}
子树修改(更新):
void update_subtree(int x, int z){
update(dfsn[x], madfsn[x], z);
}
子树查询:
int query_subtree(int x){
return query(dfsn[x], madfsn[x]);
}
需要注意,建线段树的时候不是按节点编号建,而是按dfs序建,类似这样:
for (int i = 1; i <= n; ++i)
B[i] = read();
// ...
for (int i = 1; i <= n; ++i)
A[dfsn[i]] = B[i];
build();
当然,不仅可以用线段树维护,有些题也可以使用珂朵莉树等数据结构(要求数据不卡珂朵莉树,如这道)。此外,如果需要维护的是边权而不是点权,把每条边的边权下放到深度较深的那个节点处即可,但是查询、修改的时候要注意略过最后一个点。
写在最后:
OI wiki上有一些推荐做的列题,但每个都需要比较多的时间+耐心去完成,所以这里推荐几个必做的题:
SPOJ QTREE – Query on a tree (树链剖分):千千dalao的题解报告
HDU 3966 Aragorn’s Story (树链剖分):建议先看一遍我的解法再独立完成。
参考
洛谷日报:https://zhuanlan.zhihu.com/p/41082337
OI wiki:https://oi-wiki.org/graph/hld/
Pecco学长:https://www.zhihu.com/people/one-seventh
千千:https://www.dreamwings.cn/hdu3966/4798.html
设top[a]的深度≤top[b]的深度,且c=lca(a,b)在b所在的链上;那么c是a和b的祖先且c的深度≥top[b]的深度,那么c的深度≥top[a]的深度。c是a的祖先,top[a]也是a的祖先,c的深度大于等于top[a],那c必然在连接top[a]和a的这条链上,与前提矛盾 ↩︎