树链剖分
基本概念
-
树链剖分:通过特定方法将一棵树剖分成多条不相交的链,以达到优化暴力时间复杂度的效果,通常使用 轻重链剖分 实现。树链剖分支持四种操作:
-
修改树中结点 \(x\) 到结点 \(y\) 的最短路径上所有结点的点权
-
查询树中结点 \(x\) 到结点 \(y\) 的最短路径上所有节点的点权之和
-
修改树中结点 \(x\) 及其子树的点权
-
查询树中结点 \(x\) 及其子树的点权之和
-
-
重儿子:假设结点 \(u\) 有若干子结点,其中子树大小最大的结点称为 \(u\) 的重儿子。
-
轻儿子:\(u\) 的子结点中,除了重儿子,都是轻儿子。
-
重边:连接结点 \(u\) 与其重儿子的边。
-
轻边:连接结点 \(u\) 与其轻儿子的边。
-
重链:由若干条重边组成的路径。
-
轻链:由若干条轻边组成的路径。
值得注意的是一些树链剖分的性质(建议最后阅读此部分内容):
-
一棵树中可以有多条重链、多条轻链
-
轻儿子也可以有其重儿子
-
若一个结点 \(u\) 有儿子,则 \(u\) 有且仅有一个重儿子
-
叶子结点没有重儿子
-
每条重链的链首一定是某个结点的轻儿子
-
轻儿子可以看作是一条自身到自身的长度为 \(1\) 的重链。
-
树链剖分的过程可以看成是两个结点 交替地 呈 “人”字形往上跳(当然人的两笔不一定相交)
-
从根节点到叶子节点的轻边条数和重链条数各不超过 \(logn\) 条
算法思想
算法组成
树链剖分主要使用 \(dfs\) 序和 线段树 来实现。显然,一个结点 \(u\) 及其子树在 \(dfs\) 序中一定相邻。我们可以借助这个性质来进行树链剖分。
树链剖分算法主要由三个部分组成:两个 \(dfs\) 函数和一棵线段树。其中 \(dfs1\) 函数需要完成:
-
预处理每个结点的父结点 \(fa\)
-
预处理每个结点的深度 \(dep\)
-
预处理每个结点的子树大小 \(size\)
-
预处理每个结点的重儿子 \(son\)
\(dfs2\) 函数需要完成:
-
预处理每个结点的 \(dfs\) 序 \(id\)
-
完成 \(dfs\) 序到其对应结点的映射 \(rk\)
-
记录每个结点所在重链的链首 \(top\)
而线段树则是用来维护 \(dfs\) 序中重链和子树的点权和,在下文会提及。两个 \(dfs\) 函数的实现并不难,读者可以尝试自行完成。此处提出一个要求:\(dfs2\) 函数必须先遍历重儿子,再遍历轻儿子,以保证一条重链中的结点在 \(dfs\) 序中是连续的,原因下文会提及。
操作处理
修改路径
修改路径和查询路径的思想类似于倍增的 \(LCA\) 。
首先假设待修改的结点 \(x\) 和 \(y\) 不属于同一条重链。如果 \(x\) 的链首的深度大于 \(y\) 的链首的深度,即 \(dep(top_x) \geq dep(top_y)\) ,则令 \(x\) 跳到 \(x\) 的链首的父结点,即 \(x = fa(top_x)\) ,并令线段树中从 \(x\) 的链首到 \(x\) 的区间加上修改的权值 \(w\)。否则,令 \(y\) 跳到 \(y\) 的链首的父结点,并令线段树中 \(y\) 的链首到 \(y\) 的区间加上 \(w\) 。
显然,如果重复上述过程,\(x\) 和 \(y\) 最终会到达同一条重链。设 \(x\) 是 \(x\) 和 \(y\) 中深度较小的结点,此时,将线段树中从 \(x\) 到 \(y\)的区间加上 \(x\),修改结束。
此时为什么要先遍历重儿子的原因就显而易见了。因为我们需要求出重链的权值和,又因为线段树只能维护连续区间的和,所以我们会先遍历重儿子,令一整条重链在 \(dfs\) 序中连续。
每次修改时间复杂度 \(O(log^2n)\)。
void add_path(int x, int y, int z)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
update(1, id[top[x]], id[x], z);
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
update(1, id[x], id[y], z);
}
查询路径
查询路径的过程同修改路径基本一致,同样是借助 \(LCA\) 的思想。只不过当 \(dep(top_x) \geq dep(top_y)\) 时,路径权值总和 \(sum\) 就加上线段树中从 \(x\) 的链首开始到 \(x\) 的区间的权值总和,反之同理。
每次查询时间复杂度 \(O(log^2n)\) 。
int query_path(int x, int y)
{
int sum = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
sum = (sum + query(1, id[top[x]], id[x])) % p;
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
sum = (sum + query(1, id[x], id[y])) % p;
return sum % p;
}
修改子树
先明确一个显然的性质:\(dfs\) 序中一个结点 \(u\) 和其子树是连续的,且这个区间的起点为 \(id_u\) ,终点为 \(id_u + size_u - 1\) 。有了这条性质,我们就可以方便地修改:直接令线段树中从 \(id_u\) 开始到 \(id_u + size_u - 1\) 的区间加上修改的权值 \(w\) 即可。
每次修改时间复杂度 \(O(logn)\)。
void update(int k, int l, int r, int x)
{
if (tree[k].l >= l && tree[k].r <= r)
{
tree[k].sum = (tree[k].sum + (tree[k].r - tree[k].l + 1) * x) % p;
tree[k].lazy += x;
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid)
update(2 * k, l, r, x);
if (r > mid)
update(2 * k + 1, l, r, x);
push_up(k);
}
scanf("%d%d", &x, &z);
update(1, id[x], id[x] + size[x] - 1, z);
查询子树
基本思想同上,直接查询线段树中 \(id_u\) 到 \(id_u + size_u - 1\) 的区间总和即可。
int query(int k, int l, int r)
{
if (tree[k].l >= l && tree[k].r <= r)
return tree[k].sum;
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2, sum = 0;
if (l <= mid)
sum = (sum + query(2 * k, l, r)) % p;
if (r > mid)
sum = (sum + query(2 * k + 1, l, r)) % p;
return sum % p;
}
scanf("%d", &x);
printf("%d\n", query(1, id[x], id[x] + size[x] - 1) % p);
每次查询时间复杂度 \(O(logn)\) 。
模板
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
const int maxm = 2e5 + 5;
struct Edge
{
int to, nxt;
} edge[maxm];
struct node
{
int l, r, lazy, sum;
} tree[5 * maxn];
int n, m, r, p, cnt, tot;
int head[maxn], id[maxn], rk[maxn], son[maxn], top[maxn];
int w[maxn], dep[maxn], fa[maxn], size[maxn];
void add_edge(int u, int v)
{
cnt++;
edge[cnt].to = v;
edge[cnt].nxt = head[u];
head[u] = cnt;
}
void dfs1(int u, int f)
{
fa[u] = f;
dep[u] = dep[f] + 1;
size[u] = 1;
for (int i = head[u]; i; i = edge[i].nxt)
{
if (edge[i].to != f)
{
dfs1(edge[i].to, u);
size[u] += size[edge[i].to];
if (size[edge[i].to] > size[son[u]])
son[u] = edge[i].to;
}
}
}
void dfs2(int u, int t)
{
id[u] = ++tot;
rk[tot] = u;
top[u] = t;
if (son[u])
dfs2(son[u], t);
for (int i = head[u]; i; i = edge[i].nxt)
if (edge[i].to != fa[u] && edge[i].to != son[u])
dfs2(edge[i].to, edge[i].to);
}
void push_up(int k)
{
tree[k].sum = (tree[2 * k].sum + tree[2 * k + 1].sum) % p;
}
void push_down(int k)
{
if (tree[k].l == tree[k].r)
{
tree[k].lazy = 0;
return;
}
tree[2 * k].sum = (tree[2 * k].sum + (tree[2 * k].r - tree[2 * k].l + 1) * tree[k].lazy) % p;
tree[2 * k + 1].sum = (tree[2 * k + 1].sum + (tree[2 * k + 1].r - tree[2 * k + 1].l + 1) * tree[k].lazy) % p;
tree[2 * k].lazy += tree[k].lazy;
tree[2 * k + 1].lazy += tree[k].lazy;
tree[k].lazy = 0;
}
void build(int k, int l, int r)
{
tree[k].l = l;
tree[k].r = r;
if (l == r)
{
tree[k].sum = w[rk[l]];
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
push_up(k);
}
void update(int k, int l, int r, int x)
{
if (tree[k].l >= l && tree[k].r <= r)
{
tree[k].sum = (tree[k].sum + (tree[k].r - tree[k].l + 1) * x) % p;
tree[k].lazy += x;
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid)
update(2 * k, l, r, x);
if (r > mid)
update(2 * k + 1, l, r, x);
push_up(k);
}
int query(int k, int l, int r)
{
if (tree[k].l >= l && tree[k].r <= r)
return tree[k].sum;
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2, sum = 0;
if (l <= mid)
sum = (sum + query(2 * k, l, r)) % p;
if (r > mid)
sum = (sum + query(2 * k + 1, l, r)) % p;
return sum % p;
}
void add_path(int x, int y, int z)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
update(1, id[top[x]], id[x], z);
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
update(1, id[x], id[y], z);
}
int query_path(int x, int y)
{
int sum = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
sum = (sum + query(1, id[top[x]], id[x])) % p;
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
sum = (sum + query(1, id[x], id[y])) % p;
return sum % p;
}
int main()
{
int opt, x, y, z;
scanf("%d%d%d%d", &n, &m, &r, &p);
for (int i = 1; i <= n; i++)
scanf("%d", &w[i]);
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d", &x, &y);
add_edge(x, y);
add_edge(y, x);
}
dfs1(r, 0);
dfs2(r, 0);
build(1, 1, n);
for (int i = 1; i <= m; i++)
{
scanf("%d", &opt);
if (opt == 1)
{
scanf("%d%d%d", &x, &y, &z);
add_path(x, y, z);
}
else if (opt == 2)
{
scanf("%d%d", &x, &y);
printf("%d\n", query_path(x, y) % p);
}
else if (opt == 3)
{
scanf("%d%d", &x, &z);
update(1, id[x], id[x] + size[x] - 1, z);
}
else
{
scanf("%d", &x);
printf("%d\n", query(1, id[x], id[x] + size[x] - 1) % p);
}
}
return 0;
}
例题选讲
\(LCA\)
给定一棵树,试求树中两结点 \(x, y\) 的最近公共祖先。
\(LCA\) 问题同样可以使用树剖实现,具体流程同普通树剖基本一致:令链首深度大的那个结点跳到其链首的父结点,直到两个结点在同一条重链内为止。此时深度较小的那个结点就是结点 \(x, y\) 的最近公共祖先。原理显然,不再赘述。
int lca(int x, int y)
{
while (top[x] != top[y])
{
if (dep[top[x]] <= dep[top[y]])
swap(x, y);
x = fa[top[x]];
}
if (dep[x] >= dep[y])
swap(x, y);
return x;
}
边权树剖
给定一棵有 \(n\) 个结点的树,请写出一份代码维护以下两种操作:
-
将第 \(i\) 条边的边权修改为 \(w\)
-
查询从 \(u\) 到 \(v\) 的路径上最大的边权
对于第二种操作,若 \(u = v\),则结果为 \(0\)。
结点总数 \(n \leq 10^5\),操作总数 \(m \leq 3 \times 10^5\)
这道题与树剖模板的唯一区别在于,传统树剖维护的是点权,而这种树剖维护的是边权。如果按照维护边权的方法来维护,代码的复杂程度和时间效率都会大幅度降低。因此,我们需要考虑把边权转化成点权。
这里利用一个小技巧:对于一个结点 \(i\) 来说,它的点权等于其上方的边的边权。即设 \(i\) 的父结点为 \(u\),则 \(i\) 的点权等于 \(w_{u, i}\)。特殊地,根结点没有点权。
这样,我们就可以通过传统的树剖方法来维护边权。但是,修改和查询操作需要相应地做出改变。查询 \(x\) 到 \(y\) 的路径中最大边权时,不可以将 \(x\) 和 \(y\) 的最近公共祖先也一起比较,这样会导致最近公共祖先上方的边权也考虑进来。
至于在代码中如何避免查询到最近公共祖先,这里利用一个树剖的性质:当 \(x\) 和 \(y\) 处于同一条重链时,深度较小的那一个结点就是查询的原始结点 \(x\) 和 \(y\) 的最近公共祖先。设 \(x\) 是深度较小的结点,则我们查询的区间可以从 \(x\) 的后一个结点开始,这样就不会查询到 \(lca\)。
#include <cstdio>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 5;
const int maxm = 2e5 + 5;
struct Edge
{
int from, to, nxt, w, id;
} edge[maxm];
struct node
{
int l, r, val;
} tree[5 * maxn];
int n, cnt;
int head[maxn], pos[maxn], rk[maxn], son[maxn];
int w[maxn], size[maxn], dep[maxn], fa[maxn], top[maxn];
char opt[10];
void add_edge(int u, int v, int w)
{
cnt++;
edge[cnt].from = u;
edge[cnt].to = v;
edge[cnt].w = w;
edge[cnt].nxt = head[u];
head[u] = cnt;
}
void dfs(int u, int f)
{
for (int i = head[u]; i; i = edge[i].nxt)
{
if (edge[i].to != f)
{
w[edge[i].to] = edge[i].w;
dfs(edge[i].to, u);
}
}
}
void dfs1(int u, int f)
{
fa[u] = f;
dep[u] = dep[f] + 1;
size[u] = 1;
for (int i = head[u]; i; i = edge[i].nxt)
{
if (edge[i].to != f)
{
dfs1(edge[i].to, u);
size[u] += size[edge[i].to];
if (size[edge[i].to] > size[son[u]])
son[u] = edge[i].to;
}
}
}
void dfs2(int u, int t)
{
cnt++;
pos[u] = cnt;
rk[cnt] = u;
top[u] = t;
if (son[u])
dfs2(son[u], t);
for (int i = head[u]; i; i = edge[i].nxt)
if (edge[i].to != fa[u] && edge[i].to != son[u])
dfs2(edge[i].to, edge[i].to);
}
void push_up(int k)
{
tree[k].val = max(tree[2 * k].val, tree[2 * k + 1].val);
}
void build(int k, int l, int r)
{
tree[k].l = l;
tree[k].r = r;
if (l == r)
{
tree[k].val = w[rk[l]];
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
push_up(k);
}
void update(int k, int x, int w)
{
if (tree[k].l == tree[k].r)
{
tree[k].val = w;
return;
}
int mid = (tree[k].l + tree[k].r) / 2;
if (x <= mid)
update(2 * k, x, w);
else
update(2 * k + 1, x, w);
push_up(k);
}
int query(int k, int l, int r)
{
if (tree[k].l >= l && tree[k].r <= r)
return tree[k].val;
int mid = (tree[k].l + tree[k].r) / 2, val = 0;
if (l <= mid)
val = max(val, query(2 * k, l, r));
if (r > mid)
val = max(val, query(2 * k + 1, l, r));
return val;
}
int query_path(int x, int y)
{
int val = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
val = max(val, query(1, pos[top[x]], pos[x]));
x = fa[top[x]];
}
if (dep[x] > dep[y])
swap(x, y);
val = max(val, query(1, pos[x] + 1, pos[y]));
return val;
}
int main()
{
int u, v, x;
scanf("%d", &n);
for (int i = 1; i <= n - 1; i++)
{
scanf("%d%d%d", &u, &v, &x);
add_edge(u, v, x);
add_edge(v, u, x);
}
cnt = 0;
dfs(1, 0);
dfs1(1, 0);
dfs2(1, 0);
build(1, 1, n);
while (scanf("%s", opt))
{
if (opt[0] == 'C')
{
scanf("%d%d", &u, &x);
if (dep[edge[2 * u].from] > dep[edge[2 * u].to])
update(1, pos[edge[2 * u].from], x);
else
update(1, pos[edge[2 * u].to], x);
}
else if (opt[0] == 'Q')
{
scanf("%d%d", &u, &v);
if (u == v)
printf("%d\n", 0);
else
printf("%d\n", query_path(u, v));
}
else
break;
}
return 0;
}