树链剖分学习笔记
简介
树链剖分,顾名思义,就是把树剖分成链,在链上进行一系列的操作。
下面我们就来学习一下这个算法。
概念
树链剖分引入了很多新的概念:
- 重儿子:一个节点所有的儿子中子树\(size\)最大的儿子。
- 轻儿子:一个节点的儿子中除了重儿子都是轻儿子。
- 重边:一个节点与它的重儿子所组成的边。
- 轻边:一个节点与它的轻儿子组成的边。
- 重链:若干条重边组成的链。
- 轻链:若干条轻边组成的链。
思想
树链剖分经常与线段树相结合进行链上的操作。
因此线段树是必须要掌握的。
树链剖分一开始要进行\(2\)遍\(dfs\)。
第一次\(dfs\)需要记录出一个节点的父亲、节点的深度和节点的重儿子。
第二次\(dfs\)需要对每个节点进行重新标号,按照重儿子优先的顺序遍历;还要记录出节点所在链的顶端;以及当前标号的点的编号。
然后就是线段树的基本操作。
对链进行维护时需要将两端点往上跳,直到它们在同一条剖分好的链上。
代码
这里以ZJOI2008 树的统计为例题讲解一下树链剖分的代码。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cctype>
#include <string>
#define itn int
#define gI gi
using namespace std;
inline int gi()
{
int f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar();}
return f * x;
}
int q, n, m;
int tot, head[100003], nxt[100003], ver[100003];
int dfn[100003], dep[100003], fa[100003];
int top[100003], son[100003], sz[100003];
int pre[100003], tim;
int a[100003];
inline void add(int u, int v)//邻接表存图
{
ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;
}
void dfs1(itn u, int f)//第一次dfs
{
fa[u] = f/*记录父亲*/, sz[u] = 1/*记录子树大小*/, dep[u] = dep[f] + 1;/*标记深度*/
int maxsize = -1;//最大子树大小
for (itn i = head[u]; i; i = nxt[i])//遍历子节点
{
int v = ver[i];
if (v == f) continue;
dfs1(v, u);
sz[u] = sz[u] + sz[v];//计算子树大小
if (sz[v] > maxsize)//当前子树大小超过当前最大的子树大小
{
maxsize = sz[v], son[u] = v;//更新最大子树大小并标记重儿子
}
}
}
void dfs2(int u, int f)
{
dfn[u] = ++tim/*将树重新标号*/, top[u] = f/*记录链顶*/, pre[tim] = u/*重新编号后编号为tim的节点编号*/;
if (son[u]) dfs2(son[u], f);//优先遍历重儿子
for (itn i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == son[u] || v == fa[u]) continue;//处理过就不需要再处理了
dfs2(v, v);//找出下一条链
}
}
/******以下为线段树******/
int sum[400003], maxs[400003];
inline int ls(int u) {return u << 1;}//左儿子
inline int rs(int u) {return (u << 1) | 1;}//右儿子
inline void pushup(int p)//上传标记
{
sum[p] = sum[ls(p)] + sum[rs(p)];//区间和
maxs[p] = max(maxs[ls(p)], maxs[rs(p)]);//区间最大值
}
void build(int l, int r, itn p)//建树
{
if (l == r) {sum[p] = maxs[p] = a[pre[l]];/*注意是pre[l]*/ return;}//子节点
int mid = (l + r) >> 1;
build(l, mid, ls(p)); build(mid + 1, r, rs(p));
pushup(p);//上传节点
}
void update(int x, int y, itn l, int r, int p)//更新节点信息
{
if (l == r) {sum[p] = maxs[p] = y; return;}//找到了要更新的节点
int mid = (l + r) >> 1;
if (x <= mid) update(x, y, l, mid, ls(p));//左区间寻找
else update(x, y, mid + 1, r, rs(p));//右区间寻找
pushup(p);//上传节点
}
itn getmax(int ql, int qr, int l, itn r, int p)//区间最大值查找
{
if (ql <= l && r <= qr) return maxs[p];//当前区间包含于要寻找的区间
itn mid = (l + r) >> 1, ans = -1000000000;
if (ql <= mid) ans = max(ans, getmax(ql, qr, l, mid, ls(p)));//向左寻找最大值
if (qr > mid) ans = max(ans, getmax(ql, qr, mid + 1, r, rs(p)));//向右寻找最大值
pushup(p);//上传节点
return ans;//返回答案
}
itn getans(int ql, int qr, int l, itn r, int p)//区间和查找,与区间最大值查找没有什么区别
{
if (ql <= l && r <= qr) return sum[p];
itn mid = (l + r) >> 1, ans = 0;
if (ql <= mid) ans = ans + getans(ql, qr, l, mid, ls(p));
if (qr > mid) ans = ans + getans(ql, qr, mid + 1, r, rs(p));
pushup(p);
return ans;
}
/******以上为线段树******/
inline int qmax(int l, itn r)//查找路径上最大值
{
itn ans = -1000000000;
while (top[l] != top[r])//不在同一条链上
{
if (dep[top[l]] < dep[top[r]]) swap(l, r);//找链顶深度大的节点
ans = max(ans, getmax(dfn[top[l]], dfn[l], 1, n, 1));//更新最大值
l = fa[top[l]];//跳到当前链顶的父亲
}
if (dep[l] > dep[r]) swap(l, r);//要满足左端点深度小
ans = max(ans, getmax(dfn[l], dfn[r], 1, n, 1));//更新答案
return ans;//返回
}
inline int qsum(int l, itn r)//求路径权值和,与查找最大值同理
{
itn ans = 0;
while (top[l] != top[r])
{
if (dep[top[l]] < dep[top[r]]) swap(l, r);
ans = ans + getans(dfn[top[l]], dfn[l], 1, n, 1);
l = fa[top[l]];
}
if (dep[l] < dep[r]) swap(l, r);
ans = ans + getans(dfn[r], dfn[l], 1, n, 1);
return ans;
}
int main()
{
n = gi();
for (int i = 1; i < n; i+=1)
{
int u = gI(), v = gI();
add(u, v), add(v, u);
}
for (int i = 1; i <= n; i+=1) a[i] = gi();
dfs1(1, -1); dfs2(1, 1); build(1, n, 1);//预处理
q = gi();
while (q--)
{
char s[10];
scanf("%s", s);
int u = gi(), v = gi();
if (s[1] == 'M') printf("%d\n", qmax(u, v));//区间最大值查找
else if (s[1] == 'S') printf("%d\n", qsum(u, v));//求区间和
else update(dfn[u], v, 1, n, 1);//更新节点
}
return 0;
}
应用
树链剖分求\(\texttt{LCA}\)。
代码如下(以洛谷模板为例):
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cctype>
#define itn int
#define gI gi
using namespace std;
inline int gi()
{
int f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar();}
while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar();}
return f * x;
}
int n, m, rt, dfn[500003], dep[500003], fa[500003], sz[500003], son[500003], pre[500003], top[500003];
int tot, head[2000003], nxt[2000003], ver[2000003];
inline void add(itn u, int v)
{
ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;
}
void dfs1(int u, int f)
{
fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
for (itn i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == f) continue;
dfs1(v, u);
sz[u] = sz[u] + sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
int tim;
void dfs2(itn u, int f)
{
top[u] = f, dfn[u] = ++tim, pre[tim] = u;
if (!son[u]) return;
dfs2(son[u], f);
for (int i = head[u]; i; i = nxt[i])
{
int v = ver[i];
if (v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
int main()
{
n = gi(), m = gi(), rt = gi();
for (itn i = 1; i < n; i+=1)
{
int u = gi(), v = gi();
add(u, v), add(v, u);
}
dfs1(rt, rt);
dfs2(rt, rt);
while (m--)
{
int u = gi(), v = gi();
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
u = fa[top[u]];
}
if (dep[u] < dep[v]) printf("%lld\n", u);
else printf("%lld\n", v);
}
return 0;
}
总结
理解一个算法的思想很重要。
代码要熟练地打出来才算真正理解。
记录一下我踩过的坑:
-
建树时把\(pre[l]\)写成了\(l\);
-
跳端点时没有注意左端点编号小于右端点编号;
-
子树\(size\)初始化成\(0\);
-
求\(\texttt{LCA}\)时把
<
写成>
。
就这样吧~