树链剖分

树链剖分

1 基础理论

1.1 基础概念

在树链剖分中,我们将会遇到如下的名词,在此先做以解释:

  • 重儿子:对于一个子节点 u 如果 v 是其儿子,且 v 的子树大小是节点 u 的儿子中最大的,则称 vu 的重儿子。
  • 轻儿子:除了重儿子以外,就是轻儿子。
  • 重链:除顶部以外,其余节点都为重儿子的一条路径。

1.2 树剖的基本思想

树链剖分,就是将树分割成若干条链,然后利用一些数据结构来维护这些链的方法。

树链剖分有多种形式,一般情况下,树链剖分都是指重链剖分。

在重链剖分中,以重链来剖分树(此时将落单节点也看做重链)。

1.3 代码实现

我们首先给出一些定义:

  • fa(x) 表示 x 的父亲
  • dep(x) 表示 x 的深度
  • siz(x) 表示 x 的子树大小
  • son(x) 表示 x 的重儿子
  • top(x)x 所在重链中深度最小的节点,即顶端节点
  • dfn(x)x 的 DFS 序
  • rnk(x) 表示 DFS 序所对应的节点编号,有 rnk(dfn(x))=x

实现树剖时,使用两个 DFS 完成操作。

在第一个 DFS 中,求出 fa(x),dep(x),siz(x),son(x),代码如下:

void dfs1(int p) {
	son[p] = -1;
	siz[p] = 1;
	for(int i = head[p]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to == fa[p]) continue;
		fa[to] = p;
		dep[to] = dep[p] + 1;
		dfs1(to);
		siz[p] += siz[to];
		if(son[p] == -1 || siz[to] > siz[son[p]]) {
			son[p] = to;
		}
	}
}

第二个 DFS 中,我们要求出 top(x),dfn(x),rnk(x),代码如下:

void dfs2(int p, int rt) {
	top[p] = rt;
	dfn[p] = ++ind;
	rnk[ind] = p;
	if(son[p] == -1) return;
	dfs2(son[p], rt); //注意这个
	for(int i = head[p]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to != son[p] && to != fa[p]) {
			dfs2(to, to);
		}
	} 
}

注意在第二次 DFS 中,我们是优先遍历重儿子,因为这样可以让重链上的节点的 DFS 序是连续的,方便后面利用数据结构维护。

2 树剖的基本应用

2.1 树剖求 LCA

树剖可以在 O(logn) 复杂度内求出 LCA,常数也较小。

我们考察待求节点 u,v 的情况:

  1. u,v 在同一条重链上,则 lca(u,v) 就是两者中深度较小的。
  2. u,v 不在同一条重链上,我们先求出他们链头结点 top(u)top(v),然后每一次将深度更大的 top 向父亲跳,直到 u,v 在同一条重链上。

代码如下:

#include <bits/stdc++.h>

using namespace std;

typedef long long LL ;
const int Maxn = 7e5 + 5;

int head[Maxn], edgenum;
struct node{
	int nxt, to;
}edge[Maxn];

void add(int from, int to) {
	edge[++edgenum].nxt = head[from];
	edge[edgenum].to = to;
	head[from] = edgenum;
}

int fa[Maxn], dep[Maxn], siz[Maxn], son[Maxn], top[Maxn], dfn[Maxn], rnk[Maxn];
int ind;

void dfs1(int p) {
	son[p] = -1;
	siz[p] = 1;
	for(int i = head[p]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to == fa[p]) continue;
		fa[to] = p;
		dep[to] = dep[p] + 1;
		dfs1(to);
		siz[p] += siz[to];
		if(son[p] == -1 || siz[to] > siz[son[p]]) {
			son[p] = to;
		}
	}
}

void dfs2(int p, int rt) {
	top[p] = rt;
	dfn[p] = ++ind;
	rnk[ind] = p;
	if(son[p] == -1) return;
	dfs2(son[p], rt); //注意这个
	for(int i = head[p]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to != son[p] && to != fa[p]) {
			dfs2(to, to);
		}
	} 
}

int lca(int u, int v) {
	while(top[u] != top[v]) {//暴力跳到同一条重链上
		if(dep[top[u]] < dep[top[v]]) {//取深度较大的跳
			swap(u, v);
		}
		u = fa[top[u]];//调到链头父亲
	}
	return dep[u] > dep[v] ? v : u;//返回深度较小的节点
}

int n, m, s;

int main() {
	ios::sync_with_stdio(0);
	cin >> n >> m >> s;
	for(int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		add(u, v);
		add(v, u); 
	}
	dfs1(s);
	dfs2(s, s);
	for(int i = 1; i <= m; i++) {
		int u, v;
		cin >> u >> v;
		cout << lca(u, v) << '\n';
	}
	return 0;
}

在洛谷上测试,倍增法的耗时为 3.65s,而树剖法的耗时为 2.49s

2.2 树剖与树上操作

树剖的另一经典应用为完成一系列树上操作。

我们直接来看模板题 P3384 【模板】重链剖分/树链剖分 - 洛谷

2.2.1 树链修改与查询

首先,先以每一个节点来建立线段树。

我们发现,操作 1 和操作 2 都是对树链进行修改或查询的。

我们借鉴上面求 LCA 的思路。考察两点 u,v 的关系。

  • u,v 在同一条重链上时,此时由于重链上的编号是连续的,所以答案区间就是 [dfn(u),dfn(v)]
  • u,v 不在同一条重链上时,仿照 LCA 的方法,先求出他们链头结点 top(u)top(v),然后每一次将深度更大的 top 向父亲跳,直到 u,v 在同一条重链上。注意在跳的时候要持续累加链上的和。

查询修改都如上。

代码如下:

int query(int u, int v) {//查询,修改同理
	int sum = 0;
	while(top[u] != top[v]) {
		if(dep[top[u]] < dep[top[v]]) {
			swap(u, v);
		}
		sum += seg.query(1, dfn[top[u]], dfn[u]);//seg 是线段树
		u = fa[top[u]];
	}
	if(dep[u] > dep[v]) {
		swap(u, v);
	}
	sum += seg.query(1, dfn[u], dfn[v]);
	return sum;
}

2.2.2 子树修改与查询

对于一个节点 u,其子树的 DFS 序都在 [dfn(u),dfn(u)+siz(u)1] 之中。

这个结论很明显,因为子树的 DFS 序都是在遍历这个子树时产生的。

2.2.3 参考代码

#include <bits/stdc++.h>

using namespace std;

typedef long long LL ;
const int Maxn = 1e5 + 5;

int head[Maxn << 1], edgenum;
struct node{
	int nxt, to;
}edge[Maxn << 1];

void add(int from, int to) {
	edge[++edgenum].nxt = head[from];
	edge[edgenum].to = to;
	head[from] = edgenum;
}

int n, m, r, mod;
int w[Maxn];

int fa[Maxn], dep[Maxn], siz[Maxn], son[Maxn], top[Maxn], dfn[Maxn], rnk[Maxn];
int ind;

void dfs1(int p) {
	son[p] = -1;
	siz[p] = 1;
	for(int i = head[p]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to == fa[p]) continue;
		fa[to] = p;
		dep[to] = dep[p] + 1;
		dfs1(to);
		siz[p] += siz[to];
		if(son[p] == -1 || siz[to] > siz[son[p]]) {
			son[p] = to;
		}
	}
}

void dfs2(int p, int rt) {
	top[p] = rt;
	dfn[p] = ++ind;
	rnk[ind] = p;
	if(son[p] == -1) return;
	dfs2(son[p], rt); 
	for(int i = head[p]; i; i = edge[i].nxt) {
		int to = edge[i].to;
		if(to != son[p] && to != fa[p]) {
			dfs2(to, to);
		}
	} 
}

struct segment_tree {
	struct node {
		int l, r, lazy, sum;
	}t[Maxn << 2];
	#define lp (p << 1)
	#define rp (p << 1 | 1)
	void update(int p) {
		t[p].sum = (t[lp].sum + t[rp].sum) % mod;
	}
	void build(int p, int l, int r) {
		t[p].l = l, t[p].r = r;
		if(l == r) {
			t[p].sum = w[rnk[l]];
			return ;
		}
		int mid = (l + r) >> 1;
		build(lp, l, mid);
		build(rp, mid + 1, r);
		update(p);	
	}
	void pushdown(int p) {
		if(t[p].lazy != 0 && t[p].l != t[p].r) {
			(t[lp].lazy += t[p].lazy) %= mod;
			(t[rp].lazy += t[p].lazy) %= mod;
			(t[lp].sum += t[p].lazy * (t[lp].r - t[lp].l + 1)) %= mod;
			(t[rp].sum += t[p].lazy * (t[rp].r - t[rp].l + 1)) %= mod;
			t[p].lazy = 0;
		}
	}
	void change(int p, int l, int r, int x) {
		pushdown(p);
		if(t[p].l == l && t[p].r == r) {
			(t[p].lazy += x) %= mod;
			(t[p].sum += x * (t[p].r - t[p].l + 1)) % mod;
			return ;
		}
		int mid = (t[p].l + t[p].r) >> 1;
		if(r <= mid) {
			change(lp, l, r, x);
		}
		else if(l > mid) {
			change(rp, l, r, x);
		}
		else {
			change(lp, l, mid, x);
			change(rp, mid + 1, r, x);
		}
		update(p);
	}
	int query_sum(int p, int l, int r) {
		pushdown(p);
		if(t[p].l == l && t[p].r == r) {
			return t[p].sum %= mod;
		}
		int mid = (t[p].l + t[p].r) >> 1;
		if(r <= mid) {
			return query_sum(lp, l, r);
		}
		else if(l > mid) {
			return query_sum(rp, l, r);
		}
		else {
			return (query_sum(lp, l, mid) + query_sum(rp, mid + 1, r)) % mod;
		}
		update(p);
	}	
}seg;

int Cchange(int u, int v, int x) {
	while(top[u] != top[v]) {
		if(dep[top[u]] < dep[top[v]]) {
			swap(u, v);
		}
		seg.change(1, dfn[top[u]], dfn[u], x);
		u = fa[top[u]];
	}
	if(dep[u] > dep[v]) {
		swap(u, v);
	}
	seg.change(1, dfn[u], dfn[v], x);
}

int Cquery(int u, int v) {
	int sum = 0;
	while(top[u] != top[v]) {
		if(dep[top[u]] < dep[top[v]]) {
			swap(u, v);
		}
		sum += seg.query_sum(1, dfn[top[u]], dfn[u]);
		sum %= mod;
		u = fa[top[u]];
	}
	if(dep[u] > dep[v]) {
		swap(u, v);
	}
	sum += seg.query_sum(1, dfn[u], dfn[v]);
	sum %= mod;
	return sum;
}

void Schange(int u, int x) {
	seg.change(1, dfn[u], dfn[u] + siz[u] - 1, x);
}

int Squery(int u) {
	return seg.query_sum(1, dfn[u], dfn[u] + siz[u] - 1) % mod;
}

signed main() {
	ios::sync_with_stdio(0);
	cin >> n >> m >> r >> mod;
	for(int i = 1; i <= n; i++) {
		cin >> w[i];
	}
	for(int i = 1; i < n; i++) {
		int u, v;
		cin >> u >> v;
		add(u, v);
		add(v, u);
	}
	dfs1(r);
	dfs2(r, r);
	seg.build(1, 1, n);
	while(m--) {
		int opt, x, y, z;
		cin >> opt >> x;
		if(opt == 1) {
			cin >> y >> z;
			Cchange(x, y, z);
		}
		else if(opt == 2) {
			cin >> y;
			cout << Cquery(x, y) << '\n';
		}
		else if(opt == 3) {
			cin >> y;
			Schange(x, y);
		}
		else {
			cout << Squery(x) << '\n';
		}
	}
	return 0;
}
posted @   UKE_Automation  阅读(22)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示