[算法入门]树链剖分 - 轻重链剖分
#0.0 前置知识
为了保证本文的阅读愉快,建议熟练掌握以下知识:
#1.0 啥是树剖 & 树剖能干啥
树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。
具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。
重链剖分可以将树上的任意一条路径划分成不超过 \(\log n\) 条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的 LCA 为链的一个端点)。
重链剖分还能保证划分出的每条链上的节点 DFS 序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。
如:
- 修改 树上两点之间的路径上 所有点的值。
- 查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)。
除了配合数据结构来维护树上路径信息,树剖还可以用来 (且常数较小)地求 LCA。在某些题目中,还可以利用其性质来灵活地运用树剖。
#2.0 重链剖分
正如文题,本文主要讲述重链剖分的实现
#2.1 基本变量 & 定义
重儿子: 所有儿子中子树结点数最多的儿子
重边:由结点与该点重儿子组成的树边
重链:由若干条重边首尾相连组成的长链
depth[x]
结点 \(x\) 的深度f[x]
结点 \(x\) 的父结点编号son[x]
结点 \(x\) 的重儿子编号size[x]
以 \(x\) 为根的结点个数(包括自己)top[x]
\(x\) 所在重链的顶部结点id[x],rk[y]
字典(?)数组,将原树中的结点映射到线段树中的结点与它的逆操作
#2.2 读入 & 储存
我们这里采用邻接表储存双向边(每一条树边存两个方向)
#2.3 解剖的实现 - 1
这是树剖的核心操作一
这一部分,我们主要是要完成以下任务:
- 搞清出每个结点的深度
- 摸清父子关系,统计出以以 \(x\) 为根的结点个数
size[x]
- 找到重儿子
很显然,以上三个任务可以通过一次 dfs 解决
void dfs1(int x,int fa,int depth){
f[x] = fa;
d[x] = depth;
size[x] = 1;
for (int i = head[x];i != -1;i = e[i].next){
int v = e[i].v;
if (v == fa) //注意,因为是双向边,会有连向父结点的边
continue;
dfs1(v,x,depth + 1); //递归求解子结点
size[x] += size[v];
if (size[v] > size[son[x]]) //更新重儿子
son[x] = v;
}
}
#2.4 解剖的实现 - 2
#2.4.1 目标及实现
这次,我们要完成的任务有
- 确定所在重链的头
- 捋清与线段树上结点编号的关系
显然,依旧可以采用 dfs 实现,这里有几个值得强调的点:
- 我们将一个落单的点看做一个新的重链,并将它作为以他为根的重儿子所在重链的头
- 优先递归重儿子(原因在下面)
void dfs2(int u,int t){
top[u] = t;
id[u] = tot;
rk[tot ++] = u;
if (!son[u]) //叶子结点,直接返回
return;
dfs2(son[u],t); //优先递归重儿子
for (int i = head[u];i != -1;i = e[i].next){
int v= e[i].v;
if (v != son[u] && v != f[u])
dfs2(v,v); //注意不要重复递归与错误递归
}
}
#2.4.2 对于某些问题の解释
Q:为啥优先递归重儿子?
别着急,我们先来看看如果这样做会有怎样的效果
显然,这样做一条重链上的所有结点对应的编号连续(dfs 序连续),这样的话,如果我们要按照 dfs 序建立一颗线段树,那么一条重链对应的结点编号显然是连续的,我们如果要对一条重链进行修改,可以直接修改一个区间
#3.0 求 LCA
好的,经过以上两场手术,一颗完整的树已经被解剖成为大大小小的链了,在进行更深一步的探索前,我们先来学习一个必备的操作——用重链求树上两点的 LCA
#3.1 简单分析
首先,我们考虑什么情况可以直接返回 \(x,y\) 的 LCA?
显然是二者在同一条重链上时,那么,为了达到这个目标,我们只需要每次让所在重链的头深度较大的结点向上跳,跳到该重链的头的父节点,重复这个过程即可
#3.2 实现
inline int lca(int a,int b){
while (top[a] != top[b]){
if (d[top[a]] < d[top[b]])
swap(a,b);
a = f[top[a]];
}
if (d[a] < d[b])
swap(a,b);
return b;
}
#4.0 利用数据结构维护信息
这里以线段树为例(因为我不会其他的)
假如我们要维护这样一棵树:
- 可为从 \(a\) 到 \(b\) 的路径上的每个结点加 \(x\)
- 为以 \(a\) 为根的子树上的所有结点加 \(x\)
- 求从 \(a\) 到 \(b\) 的路径上的每个结点的权值和
- 求 \(a\) 为根的子树上的所有结点的权值和
让我们一样一样来
#4.1 路径修改
首先,从 \(a\) 到 \(b\) 的路径是什么?
是不是就是从 \(a\) 到 \(\text{LCA}(a,b)\),再从 \(\text{LCA}(a,b)\) 到 \(b\),那么,我们可不可以借用上面求 LCA 的思想,每跳一次,给重链进行区间修改,最终达到修改整个路径
还记得我们解剖时的一个操作吗?让一条重链上的结点对应的线段树上的编号连续,所以就可以直接进行区间修改
inline void update(int l,int r,int x,int k){
if (l <= a[k].l && r >= a[k].r){
add(k,x);
return;
}
pushdown(k);
int mid = (a[k].l + a[k].r) >> 1;
if (mid >= l)
update(l,r,x,a[k].ls);
if (mid < r)
update(l,r,x,a[k].rs);
pushup(k);
}
inline void updates(int x,int y,int c){
while (top[x] != top[y]){
if ( d[top[x]] < d[top[y]])
swap(x,y);
update(id[top[x]],id[x],c,rt);
x = f[top[x]];
}
if (id[x] > id[y])
swap(x,y);
update(id[x],id[y],c,rt);
}
#4.2 子树修改
首先,dfs 序具有这样一个性质:一颗子树上的点 dfs 序连续
所以我们就可以直接采用如下代码进行区间修改
update(id[x],id[x] + size[x] - 1,y,rt);
#4.3 路径查询
与路径修改一样的思想
inline int query(int l,int r,int k){
if (l <= a[k].l && a[k].r <= r)
return a[k].sum;
pushdown(k);
int mid = (a[k].l + a[k].r) >> 1;
int ans = 0;
if (mid >= l)
(ans += query(l,r,a[k].ls)) %= mod;
if (mid < r)
(ans += query(l,r,a[k].rs)) %= mod;
return ans;
}
inline int sum(int x,int y){
int ans = 0;
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]])
swap(x,y);
(ans += query(id[top[x]],id[x],rt)) %= mod;
x = f[top[x]];
}
if (id[x] > id[y])
swap(x,y);
return (ans + query(id[x],id[y],rt)) % mod;
}
#4.4 子树查询
与子树修改类似
#5.0 例题
板子题
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <queue>
#define INF 0x3fffffff
#define N 100100
#define ll long long
#define mset(l,x) memset(l,x,sizeof(l))
#define mp(a,b) make_pair(a,b)
using namespace std;
struct Edge{
int u;
int v;
int next;
};
Edge e[N * 4];
struct Node{
int l,r; //区间左右端点
int ls,rs; //左右子树根节点编号
int sum; //区间和
int lazy;
};
Node a[N * 4];
int n,m,root,mod;
int value[N],head[N];
int d[N],f[N],size[N],son[N],tot;
int top[N],id[N],rk[N],rt;
inline void add_edge(int u,int v){
e[tot].u = u;
e[tot].v = v;
e[tot].next = head[u];
head[u] = tot ++;
}
void dfs1(int x,int fa,int depth){
f[x] = fa;
d[x] = depth;
size[x] = 1;
for (int i = head[x];i != -1;i = e[i].next){
int v = e[i].v;
if (v == fa)
continue;
dfs1(v,x,depth + 1);
size[x] += size[v];
if (size[v] > size[son[x]])
son[x] = v;
}
}
void dfs2(int u,int t){
top[u] = t;
id[u] = tot;
rk[tot ++] = u;
if (!son[u])
return;
dfs2(son[u],t);
for (int i = head[u];i != -1;i = e[i].next){
int v= e[i].v;
if (v != son[u] && v != f[u])
dfs2(v,v);
}
}
inline int len(int k){
return a[k].r - a[k].l + 1;
}
inline void pushup(int k){
a[k].sum = (a[a[k].ls].sum + a[a[k].rs].sum) % mod;
}
inline void add(int k,int x){
(a[k].lazy += x) %= mod;
(a[k].sum += len(k) * x) %= mod;
}
inline void build(int l,int r,int k){
if (l == r){
a[k].sum = value[rk[l]];
a[k].l = a[k].r = l;
return;
}
int mid = (l + r) >> 1;
a[k].ls = tot ++;
build(l,mid,a[k].ls);
a[k].rs = tot ++;
build(mid + 1,r,a[k].rs);
a[k].l = a[a[k].ls].l,a[k].r = a[a[k].rs].r;
pushup(k);
}
inline void pushdown(int k){
if (a[k].lazy){
int ls = a[k].ls,rs = a[k].rs;
add(ls,a[k].lazy);
add(rs,a[k].lazy);
a[k].lazy = 0;
}
}
inline void update(int l,int r,int x,int k){
if (l <= a[k].l && r >= a[k].r){
add(k,x);
return;
}
pushdown(k);
int mid = (a[k].l + a[k].r) >> 1;
if (mid >= l)
update(l,r,x,a[k].ls);
if (mid < r)
update(l,r,x,a[k].rs);
pushup(k);
}
inline int query(int l,int r,int k){
if (l <= a[k].l && a[k].r <= r)
return a[k].sum;
pushdown(k);
int mid = (a[k].l + a[k].r) >> 1;
int ans = 0;
if (mid >= l)
(ans += query(l,r,a[k].ls)) %= mod;
if (mid < r)
(ans += query(l,r,a[k].rs)) %= mod;
return ans;
}
inline int sum(int x,int y){
int ans = 0;
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]])
swap(x,y);
(ans += query(id[top[x]],id[x],rt)) %= mod;
x = f[top[x]];
}
if (id[x] > id[y])
swap(x,y);
return (ans + query(id[x],id[y],rt)) % mod;
}
inline void updates(int x,int y,int c){
while (top[x] != top[y]){
if ( d[top[x]] < d[top[y]])
swap(x,y);
update(id[top[x]],id[x],c,rt);
x = f[top[x]];
}
if (id[x] > id[y])
swap(x,y);
update(id[x],id[y],c,rt);
}
int main(){
mset(head,-1);
scanf("%d%d%d%d",&n,&m,&root,&mod);
for (int i = 1;i <= n;i ++)
scanf("%d",&value[i]);
for (int i = 1;i < n;i ++){
int u,v;
scanf("%d%d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
dfs1(root,0,1);
tot = 1;
dfs2(root,root);
tot = 0;
build(1,n,rt = tot ++);
while (m --){
int s;
scanf("%d",&s);
if (s == 1){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
updates(x,y,z);
}
else if (s == 2){
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",sum(x,y));
}
else if (s == 3){
int x,y;
scanf("%d%d",&x,&y);
update(id[x],id[x] + size[x] - 1,y,rt);
}
else{
int x;
scanf("%d",&x);
printf("%d\n",query(id[x],id[x] + size[x] - 1,rt));
}
}
return 0;
}