算法学习笔记(16):Link Cut Tree
Link Cut Tree
简称LCT(不是Li Chao Tree), 是一种非常强大的数据结构。
声明
该博客写来很大部分目的是帮助自己理解, 笔者水平有限, 没办法完全原创, 有很多内容源自于OI-wiki,和网上博客, 见谅。
功能
考虑一些问题:
- 树上单点查, 树上路径修改, 这是树上差分可以解决的。
- 那么如果路径查, 路径修呢? 或者再加上子树查, 子树修改, 单点查呢? 树链剖分可以解决。
- 那么再加上可以任意删边连边呢?(保证一直都是一棵树), 这时候就需要LCT了。
所以LCT可以解决:
- 修改两点间路径权值。
- 查询两点间路径权值和。
- 修改某点子树权值。
- 查询某点子树权值和。
- 断边连边。
- 还可以维护一些奇怪信息, 这是因为Splay可以搞区间操作。
算法
实链剖分
这是比较类似重链剖分的, 重链剖分的重链是根据子树的大小来确定的, 而我们的实链是自己钦定的, 随时可以变化。
所以原树会被我们剖成若干个实链, 每条实链由一颗Splay维护。 并且这个Splay是根据深度作为BST的键值的。
所以, 就会有一些美好性质:
- 辅助树由多棵 Splay 组成,每棵 Splay 维护原树中的一条路径,且中序遍历这棵 Splay 得到的点序列,从前到后对应原树「从上到下」的一条路径。
- 原树每个节点与辅助树的 Splay 节点一一对应。
- 辅助树的各棵 Splay 之间并不是独立的。每棵 Splay 的根节点的父亲节点本应是空,但在 LCT 中每棵 Splay 的根节点的父亲节点指向原树中 这条链 的父亲节点(即链最顶端的点的父亲节点)。这类父亲链接与通常 Splay 的父亲链接区别在于儿子认父亲,而父亲不认儿子,对应原树的一条 虚边。因此,每个连通块恰好有一个点的父亲节点为空。
- 由于辅助树的以上性质,我们维护任何操作都不需要维护原树,辅助树可以在任何情况下拿出一个唯一的原树,我们只需要维护辅助树即可。
操作
考虑我们访问树上路径, 这就需要两个很重要的操作。
Access()
void access(int u) {
for (int f = 0; u; f = u, u = fa[u])
splay(u), ch[u][1] = f, maintain(u);
}
表示将节点 \(x\) 到根的路径全部钦定为实边, 这样就可以得到 \(x\) 到根的路径了。 因为只有实边才会被Splay维护, 才可以进行操作。
考虑实现过程, 就是把 \(x\) 翻到所在Splay的根, 然后把他父亲的右儿子指向它(这一步相当于就是把这条虚边钦定为实边), 一直到根, 就OK了。
Makeroot()
void makeroot(int u) {
access(u);
splay(u);
swap(ls, rs);
rev[u] ^= 1;
}
考虑access()
操作只能访问 \(x\) 到根的路径, 但实际上树上路径不一定深度严格递增, 解决办法就是把某一个点钦定为根, 考虑一个性质, 如果把一个树看成一个内向树或者外向树, 那么换根就是把 \(x\) 到原来的根的路径上的边全部反向, 所以我们用Splay的区间翻转, 就可以实现换根。
实现过程: 访问 \(x\) 到根节点的路径, 也就是 access(x)
, 然后将这条路径的Splay的节点 \(x\) 翻到根, 然后打翻转标记就行了。
其他操作就很简单啦, 建议自行yy以加强对这两个核心操作的理解。
模板题
P1501 [国家集训队] Tree II
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5 + 10;
const int mod = 51061;
int n, q;
char op;
struct Splay{
int ch[N][2], fa[N], siz[N], val[N], sum[N], rev[N], add[N], mul[N];
#define ls ch[u][0]
#define rs ch[u][1]
void clear(int u) {
ls = rs = fa[u] = siz[u] = val[u] = sum[u] = rev[u] = add[u] = 0;
mul[u] = 1;
}
int get(int u) { return (ch[fa[u]][1] == u); }
int isroot(int u) {
clear(0);
return ch[fa[u]][0] != u && ch[fa[u]][1] != u;
}
void maintain(int u) {
clear(0);
siz[u] = (siz[ls] + 1 + siz[rs]) % mod;
sum[u] = (sum[ls] + val[u] + sum[rs]) % mod;
}
void pushdown(int u) {
clear(0);
if (mul[u] != 1) {
if (ls)
mul[ls] = (mul[u] * mul[ls]) % mod,
val[ls] = (mul[u] * val[ls]) % mod,
sum[ls] = (mul[u] * sum[ls]) % mod,
add[ls] = (mul[u] * add[ls]) % mod;
if (rs)
mul[rs] = (mul[u] * mul[rs]) % mod,
val[rs] = (mul[u] * val[rs]) % mod,
sum[rs] = (mul[u] * sum[rs]) % mod,
add[rs] = (mul[u] * add[rs]) % mod;
mul[u] = 1;
}
if (add[u]) {
if (ls)
add[ls] = (add[u] + add[ls]) % mod,
val[ls] = (add[u] + val[ls]) % mod,
sum[ls] = (add[u] * siz[ls] % mod + sum[ls]) % mod;
if (rs)
add[rs] = (add[u] + add[rs]) % mod,
val[rs] = (add[u] + val[rs]) % mod,
sum[rs] = (add[u] * siz[rs] % mod + sum[rs]) % mod;
add[u] = 0;
}
if (rev[u]) {
if (ls) rev[ls] ^= 1, swap(ch[ls][0], ch[ls][1]);
if (rs) rev[rs] ^= 1, swap(ch[rs][0], ch[rs][1]);
rev[u] = 0;
}
}
void update(int u) {
if (!isroot(u)) update(fa[u]);
pushdown(u);
}
void rotate(int u) {
int f = fa[u], gf = fa[f], chk = get(u);
fa[u] = gf;
if (!isroot(f)) ch[gf][f == ch[gf][1]] = u;
ch[f][chk] = ch[u][chk ^ 1];
if (ch[u][chk ^ 1]) fa[ch[u][chk ^ 1]] = f;
ch[u][chk ^ 1] = f;
fa[f] = u;
maintain(f);
maintain(u);
maintain(gf);
}
void splay(int u) {
update(u);
for (int f = fa[u]; f = fa[u], !isroot(u); rotate(u)) {
if (!isroot(f)) rotate(get(u) == get(f) ? f : u);
}
}
void access(int u) {
for (int f = 0; u; f = u, u = fa[u])
splay(u), ch[u][1] = f, maintain(u);
}
void makeroot(int u) {
access(u);
splay(u);
swap(ls, rs);
rev[u] ^= 1;
}
int find(int u) {
access(u);
splay(u);
while (ls) u = ls;
splay(u);
return u;
}
void add_edge(int u, int v) {
if (find(u) != find(v))
makeroot(u), fa[u] = v;
}
void modify_add(int u, int v, int z) {
makeroot(u); access(v); splay(v);
val[v] = (val[v] + z) % mod;
sum[v] = (sum[v] + z * siz[v] % mod) % mod;
add[v] = (add[v] + z) % mod;
}
void modify_mul(int u, int v, int z) {
makeroot(u); access(v); splay(v);
val[v] = val[v] * z % mod;
sum[v] = sum[v] * z % mod;
mul[v] = mul[v] * z % mod;
}
void cut(int u, int v) {
makeroot(u); access(v); splay(v);
if (ch[v][0] == u && !rs)
ch[v][0] = fa[u] = 0;
}
void link(int u, int v) {
makeroot(u); splay(u);
if (find(u) != find(v)) fa[u] = v;
}
int query(int u, int v) {
makeroot(u); access(v); splay(v);
return sum[v];
}
}T;
signed main() {
scanf("%lld%lld", &n, &q);
for (int i = 1; i <= n; i++)
T.val[i] = 1, T.maintain(i);
for (int i = 1, u, v; i < n; i++)
scanf("%lld%lld", &u, &v), T.add_edge(u, v);
for (int i = 1, op, u, v, z; i <= q; i++) {
scanf(" %c%lld%lld", &op, &u, &v);
switch(op) {
case '+': scanf("%lld", &z); T.modify_add(u, v, z); break;
case '-': T.cut(u, v); scanf("%lld%lld", &u, &v); T.link(u, v); break;
case '*': scanf("%lld", &z); T.modify_mul(u, v, z); break;
case '/': printf("%lld\n", T.query(u, v)); break;
}
}
return 0;
}
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
struct LCT{
int val[N], sum[N], ch[N][2], fa[N], rev[N];
#define ls ch[u][0]
#define rs ch[u][1]
void clear(int u) { val[u] = sum[u] = ls = rs = fa[u] = rev[u] = 0; }
void maintain(int u) {
clear(0);
sum[u] = sum[ls] ^ val[u] ^ sum[rs];
}
bool get(int u) { return u == ch[fa[u]][1]; }
bool isroot(int u) {
clear(0);
return ch[fa[u]][0] != u && ch[fa[u]][1] != u;
}
void pushdown(int u) {
clear(0); //好习惯
if (rev[u]) {
if (ls) swap(ch[ls][0], ch[ls][1]), rev[ls] ^= 1;
if (rs) swap(ch[rs][0], ch[rs][1]), rev[rs] ^= 1;
rev[u] = 0;
}
}
void update(int u) {
if (!isroot(u)) update(fa[u]);
pushdown(u);
}
void rotate(int u) {
int f = fa[u], gf = fa[f], chk = get(u);
if (!isroot(f)) ch[gf][f == ch[gf][1]] = u;
fa[u] = gf;
if (ch[u][chk ^ 1]) fa[ch[u][chk ^ 1]] = f;
ch[f][chk] = ch[u][chk ^ 1];
ch[u][chk ^ 1] = f;
fa[f] = u;
maintain(f); maintain(u); maintain(gf); //现在f是最下面的点
}
void splay(int u) {
update(u);//记得下传标记
for (int f = fa[u]; f = fa[u], !isroot(u); rotate(u))
if (!isroot(f)) rotate(get(u) == get(f) ? f : u);
}
void access(int u) {
for (int f = 0; u; f = u, u = fa[u])
splay(u), ch[u][1] = f, maintain(u); //变换虚实边, 连右儿子是因为下面的链深度大
}
void makeroot(int u) {
access(u);
splay(u);
rev[u] ^= 1;
swap(ls, rs);
}
int query(int u, int v) {
makeroot(u); access(v); splay(v);
return sum[v];
}
int find(int u) {
access(u); splay(u);
while (ls) u = ls;
splay(u); // 这里再splay一下是为了保证复杂度, 舍小保大, 不翻上去, 要是这里深度比较大, 然后数据疯狂访问这里不就T飞了。
return u;
}
void link(int u, int v) {
if (find(u) == find(v)) return;//这里find一定要放前面, 因为find会把u和v往上翻
makeroot(u); splay(u);
fa[u] = v;
}
void cut(int u, int v) {
if (find(u) != find(v)) return;//一定要放前面, 原因同上
makeroot(u); access(v); splay(v);
if (ch[v][0] == u && !rs) ch[v][0] = fa[u] = 0; //易错 三个条件
//此处不用maintain, 因为u, v 为splay最上方的点, 不会影响答案, link同理
}
void change(int u, int k) {
splay(u);
val[u] = k;
maintain(u);
}
}T;
int n, m;
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
scanf("%d", &T.val[i]), T.maintain(i);
for (int i = 1, op, x, y; i <= m; i++) {
scanf("%d%d%d", &op, &x, &y);
switch(op) {
case 0: printf("%d\n", T.query(x, y)); break;
case 1: T.link(x, y); break;
case 2: T.cut(x, y); break;
case 3: T.change(x, y); break;
}
}
return 0;
}