树链剖分学习笔记

字面意思 把一棵树解剖成若干条链 方便一些操作


重链剖分

大多数树链剖分都指的这个

一些概念:

  • 重儿子:子树大小最大的儿子
  • 轻儿子:除了重儿子以外的儿子
  • 重边:它和重儿子的边
  • 轻边:它和轻儿子的边
  • 重链:把每一条重边连起来形成的链

大概就是 我们把每个节点的子树大小都算出来
然后把每个点和它子树大小最大的儿子连边(其实就是把重链全连起来)
这样就可以把整棵树按重量剖分
如图:

然后我们再进行一次 \(dfs\) 把所有节点按 \(dfs\) 序重新标号 并且 \(dfs\) 儿子的时候优先 \(dfs\) 重儿子
这样它就具有以下特点:

  • 一条重链上的点编号是连续的(因为我们先搜的重儿子)
  • 一棵子树内的点编号是连续的(\(dfs\) 序的性质)
    并且对应的区间应该是 \(\left[id_x, id_x + siz_x - 1\right]\)

然后我们再看一下如何用重链剖分求 \(LCA\)
为此 我们要维护这个点所在重链最顶端那个点
这样我们在查询 \(x\)\(y\)\(LCA\) 时 只有两种情况:

  • 它俩在一条重链上 那 \(LCA\) 就是深度浅的那个
  • 它俩不在一条重链上 那就把 \(top\) 较低的点跳到链顶的父亲 重复此过程直到它俩在一条链上

这个操作是 \(O(logn)\)虽然我不会证

然后呢又因为重链的那个性质 跳链经过的点显然对应 \(dfs\) 序上的一段区间
所以就转化成了一个区间操作问题 用一些数据结构(比如线段树)维护即可

附上P3384 【模板】重链剖分/树链剖分的代码:

#include <bits/stdc++.h>
#define ll long long
#define ls (k << 1)
#define rs (k << 1 | 1)
#define mid (l + r >> 1)
using namespace std;

const int N = 1e5 + 0721;
int n, m, root, mod; 
int a[N], b[N];
int id[N], top[N], fa[N], siz[N], dep[N], son[N];
int head[N], nxt[N << 1], to[N << 1], cnt;
int dfs_clock;

struct tree {
	ll tr[N << 2], lazy[N << 2];
	
	inline void plu(int k, int l, int r, int val) {
		tr[k] = (tr[k] + (r - l + 1) * val) % mod;
		lazy[k] = (lazy[k] + val) % mod;
	}
	
	inline void pushdown(int k, int l, int r) {
		plu(ls, l, mid, lazy[k]);
		plu(rs, mid + 1, r, lazy[k]);
		lazy[k] = 0;
	}
	
	inline void pushup(int k) {
		tr[k] = (tr[ls] + tr[rs]) % mod;
	}
	
	void build(int k, int l, int r) {
		if (l == r) {
			tr[k] = a[l];
			return;
		}
		build(ls, l, mid);
		build(rs, mid + 1, r);
		pushup(k);
	}
	
	void modify(int k, int l, int r, int u, int v, int val) {
		if (u <= l && v >= r) {
			plu(k, l, r, val);
			return;
		}
		if (lazy[k] != 0) pushdown(k, l, r);
		if (u <= mid) modify(ls, l, mid, u, v, val);
		if (v > mid) modify(rs, mid + 1, r, u, v, val);
		pushup(k);
	}
	
	ll query(int k, int l, int r, int u, int v) {
		if (u <= l && v >= r) {
			return tr[k];
		}
		if (lazy[k] != 0) pushdown(k, l, r);
		ll ret = 0;
		if (u <= mid) ret = (ret + query(ls, l, mid, u, v)) % mod;
		if (v > mid) ret = (ret + query(rs, mid + 1, r, u, v)) % mod;
		pushup(k);
		return ret;
	}
} seg;

inline void add_edge(int x, int y) {
	to[++cnt] = y;
	nxt[cnt] = head[x];
	head[x] = cnt;
}

void dfs1(int x, int f) {
	siz[x] = 1;
	int maxn = 0;
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == f) continue;
		
		fa[y] = x;
		dep[y] = dep[x] + 1;
		
		dfs1(y, x);
		
		siz[x] += siz[y];
		if (siz[y] > maxn) son[x] = y; //记录重儿子 
	}
}

void dfs2(int x, int topx) {
	id[x] = ++dfs_clock;
	a[dfs_clock] = b[x]; //方便等会建树
	top[x] = topx;
	
	if (son[x] == 0) return; //没有重儿子
	dfs2(son[x], topx); //先搜重儿子
	
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (id[y] == 0) dfs2(y, y); //对轻儿子进行dfs 它自己就是它自己那条重链的头 
	} 
}

void modify1(int x, int y, int val) { //把x到y路径上的都加val 
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]]) swap(x, y); //让top更深的那个往上跳
		seg.modify(1, 1, n, id[top[x]], id[x], val); //把整条链都改了 
		x = fa[top[x]]; 
	}
	if (dep[x] < dep[y]) swap(x, y);
	seg.modify(1, 1, n, id[y], id[x], val); //现在x和y在一条链上 
}

ll query1(int x, int y) { //查询x到y路径上的点权和
	ll ret = 0; 
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		ret = (ret + seg.query(1, 1, n, id[top[x]], id[x])) % mod;
		x = fa[top[x]];
	}
	if (dep[x] < dep[y]) swap(x, y);
	ret = (ret + seg.query(1, 1, n, id[y], id[x])) % mod;
	return ret;
}

void modify2(int x, int val) { //x子树内都加val 
	seg.modify(1, 1, n, id[x], id[x] + siz[x] - 1, val);
}

ll query2(int x) { //查询x子树内的点权和 
	return seg.query(1, 1, n, id[x], id[x] + siz[x] - 1);
}

int main() {
	
	scanf("%d%d%d%d", &n, &m, &root, &mod);
	for (int i = 1; i <= n; ++i) {
		scanf("%d", &b[i]);
		b[i] %= mod; 
	} 
	for (int i = 1; i < n; ++i) {
		int x, y;
		scanf("%d%d", &x, &y);
		add_edge(x, y);
		add_edge(y, x);
	}
	
	dep[root] = 1;
	dfs1(root, 0);
	dfs2(root, root);
	seg.build(1, 1, n);
	
	while (m--) {
		int opt, x, y, z;
		scanf("%d", &opt);
		if (opt == 1) {
			scanf("%d%d%d", &x, &y, &z);
			modify1(x, y, z);
		} else if (opt == 2) {
			scanf("%d%d", &x, &y);
			printf("%lld\n", query1(x, y));
		} else if (opt == 3) {
			scanf("%d%d", &x, &y);
			modify2(x, y);
		} else {
			scanf("%d", &x);
			printf("%lld\n", query2(x));
		}
	}
	
	return 0;
}

重链剖分的应用就太多了 这里就不一一枚举了


长链剖分

其实还是字面意思 重链剖分是取子树最大的儿子 而长链剖分是取子树内叶节点最深的儿子
具体剖分方法和重链剖分基本一样 这里就不再赘述
它具有一下几个性质:

  • 所有长链链长之和为点数(显然)
  • 对于一个点 \(x\) 它的 \(k\) 级祖先 \(v\) 所在的长链一定 \(\ge k\)

证明:如果 \(x\)\(v\) 在一条长链上 那么显然
如果 \(x\)\(v\) 不在一条长链上 假设 \(v\) 所在的长链长度 \(<k\) 那么 \(x\)\(v\) 的这条链才是长链 与假设不符

  • 任意节点到达根节点跳跃次数是 \(\sqrt{n}\) 级别(重链是 \(logn\)

下面讲下长链剖分的一些应用

  • 优化与深度有关\(DP\)

CF1009F Dominant Indices

非常经典的一道长链剖分优化 \(DP\)
\(f_{i, j}\) 表示 \(i\)\(j\) 级儿子的数量
显然有 \(f_{i, j} = \sum\limits_{k \in son_i}^{} f_{k, j - 1}\)
但是这个无论是时间还是空间都是 \(O(n ^ 2)\)
我们进一步发现 这个转移式实际上就是把 \(i\) 的儿子与 \(i\) 错开一位放置
所以我们对整棵树进行长链剖分 然后把长儿子的 \(f\) 值直接记录在 \(i\) 身上
对于其它的儿子 我们进行暴力合并 每次合并复杂度就是 \(O(链长)\)
那么全合并的复杂度就是链长之和 即为 \(O(n)\)

code:

#include <bits/stdc++.h>
using namespace std;

const int N = 1e6 + 0721;
int buf[N];
int *f[N], *now = buf;
int head[N], nxt[N << 1], to[N << 1], cnt;
int dep[N], son[N];
int ans[N];
int n;

void add_edge(int x, int y) {
	to[++cnt] = y;
	nxt[cnt] = head[x];
	head[x] = cnt;
}

void dfs1(int x, int fa) {
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == fa) continue;
		dfs1(y, x);
		if (dep[y] > dep[son[x]]) son[x] = y;
	}
	dep[x] = dep[son[x]] + 1;
}

void dfs2(int x, int fa) {
	f[x][0] = 1;
	if (son[x] == 0) return;
	
	f[son[x]] = f[x] + 1; //把长儿子的dp数组指在x后一位的地方
	dfs2(son[x], x);
	ans[x] = ans[son[x]] + 1; //把长儿子合并到x上
	
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == son[x] || y == fa) continue;
		f[y] = now;
		now += dep[y]; //为y分配长度为dep[y]的内存 
		dfs2(y, x);
		
		for (int i = 1; i <= dep[y]; ++i) { //把其他儿子的信息暴力合并到x上 
			f[x][i] += f[y][i - 1];
			if (f[x][ans[x]] < f[x][i] || (f[x][ans[x]] == f[x][i])) ans[x] = i;
		}
	} 
	if (f[x][ans[x]] == 1) ans[x] = 0;
}

int main() {
	scanf("%d", &n);
	for (int i = 1; i < n; ++i) {
		int x, y;
		scanf("%d%d", &x, &y);
		add_edge(x, y);
		add_edge(y, x);
	}
	
	dfs1(1, 0);
	f[1] = now;
	now += dep[1]; //为根节点分配内存
	dfs2(1, 0);
	for (int i = 1; i <= n; ++i) printf("%d\n", ans[i]); 
	
	return 0;
}
/*
4
1 2
2 3
2 4
*/
posted @ 2023-07-16 21:33  Steven24  阅读(19)  评论(0编辑  收藏  举报