Loading

树链剖分

基本概念

  1. 树链剖分:通过特定方法将一棵树剖分成多条不相交的链,以达到优化暴力时间复杂度的效果,通常使用 轻重链剖分 实现。树链剖分支持四种操作:

    • 修改树中结点 \(x\) 到结点 \(y\) 的最短路径上所有结点的点权

    • 查询树中结点 \(x\) 到结点 \(y\) 的最短路径上所有节点的点权之和

    • 修改树中结点 \(x\) 及其子树的点权

    • 查询树中结点 \(x\) 及其子树的点权之和

  2. 重儿子:假设结点 \(u\) 有若干子结点,其中子树大小最大的结点称为 \(u\) 的重儿子。

  3. 轻儿子:\(u\) 的子结点中,除了重儿子,都是轻儿子。

  4. 重边:连接结点 \(u\) 与其重儿子的边。

  5. 轻边:连接结点 \(u\) 与其轻儿子的边。

  6. 重链:由若干条重边组成的路径。

  7. 轻链:由若干条轻边组成的路径。

值得注意的是一些树链剖分的性质(建议最后阅读此部分内容):

  • 一棵树中可以有多条重链、多条轻链

  • 轻儿子也可以有其重儿子

  • 若一个结点 \(u\) 有儿子,则 \(u\) 有且仅有一个重儿子

  • 叶子结点没有重儿子

  • 每条重链的链首一定是某个结点的轻儿子

  • 轻儿子可以看作是一条自身到自身的长度为 \(1\) 的重链。

  • 树链剖分的过程可以看成是两个结点 交替地 呈 “人”字形往上跳(当然人的两笔不一定相交)

  • 从根节点到叶子节点的轻边条数和重链条数各不超过 \(logn\)

算法思想

算法组成

树链剖分主要使用 \(dfs\) 序和 线段树 来实现。显然,一个结点 \(u\) 及其子树在 \(dfs\) 序中一定相邻。我们可以借助这个性质来进行树链剖分。

树链剖分算法主要由三个部分组成:两个 \(dfs\) 函数和一棵线段树。其中 \(dfs1\) 函数需要完成:

  • 预处理每个结点的父结点 \(fa\)

  • 预处理每个结点的深度 \(dep\)

  • 预处理每个结点的子树大小 \(size\)

  • 预处理每个结点的重儿子 \(son\)

\(dfs2\) 函数需要完成:

  • 预处理每个结点的 \(dfs\)\(id\)

  • 完成 \(dfs\) 序到其对应结点的映射 \(rk\)

  • 记录每个结点所在重链的链首 \(top\)

而线段树则是用来维护 \(dfs\) 序中重链和子树的点权和,在下文会提及。两个 \(dfs\) 函数的实现并不难,读者可以尝试自行完成。此处提出一个要求:\(dfs2\) 函数必须先遍历重儿子,再遍历轻儿子,以保证一条重链中的结点在 \(dfs\) 序中是连续的,原因下文会提及。

操作处理

修改路径

修改路径和查询路径的思想类似于倍增的 \(LCA\)

首先假设待修改的结点 \(x\)\(y\) 不属于同一条重链。如果 \(x\) 的链首的深度大于 \(y\) 的链首的深度,即 \(dep(top_x) \geq dep(top_y)\) ,则令 \(x\) 跳到 \(x\) 的链首的父结点,即 \(x = fa(top_x)\) ,并令线段树中从 \(x\) 的链首到 \(x\) 的区间加上修改的权值 \(w\)。否则,令 \(y\) 跳到 \(y\) 的链首的父结点,并令线段树中 \(y\) 的链首到 \(y\) 的区间加上 \(w\)

显然,如果重复上述过程,\(x\)\(y\) 最终会到达同一条重链。设 \(x\)\(x\)\(y\) 中深度较小的结点,此时,将线段树中从 \(x\)\(y\)的区间加上 \(x\),修改结束。

此时为什么要先遍历重儿子的原因就显而易见了。因为我们需要求出重链的权值和,又因为线段树只能维护连续区间的和,所以我们会先遍历重儿子,令一整条重链在 \(dfs\) 序中连续。

每次修改时间复杂度 \(O(log^2n)\)

void add_path(int x, int y, int z)
{
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		update(1, id[top[x]], id[x], z);
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	update(1, id[x], id[y], z);
}

查询路径

查询路径的过程同修改路径基本一致,同样是借助 \(LCA\) 的思想。只不过当 \(dep(top_x) \geq dep(top_y)\) 时,路径权值总和 \(sum\) 就加上线段树中从 \(x\) 的链首开始到 \(x\) 的区间的权值总和,反之同理。

每次查询时间复杂度 \(O(log^2n)\)

int query_path(int x, int y)
{
	int sum = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		sum = (sum + query(1, id[top[x]], id[x])) % p;
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	sum = (sum + query(1, id[x], id[y])) % p;
	return sum % p;
}

修改子树

先明确一个显然的性质:\(dfs\) 序中一个结点 \(u\) 和其子树是连续的,且这个区间的起点为 \(id_u\) ,终点为 \(id_u + size_u - 1\) 。有了这条性质,我们就可以方便地修改:直接令线段树中从 \(id_u\) 开始到 \(id_u + size_u - 1\) 的区间加上修改的权值 \(w\) 即可。

每次修改时间复杂度 \(O(logn)\)

void update(int k, int l, int r, int x)
{
	if (tree[k].l >= l && tree[k].r <= r)
	{
		tree[k].sum = (tree[k].sum + (tree[k].r - tree[k].l + 1) * x) % p;
		tree[k].lazy += x;
		return;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	if (l <= mid)
		update(2 * k, l, r, x);
	if (r > mid)
		update(2 * k + 1, l, r, x);
	push_up(k);
}

scanf("%d%d", &x, &z);
update(1, id[x], id[x] + size[x] - 1, z);

查询子树

基本思想同上,直接查询线段树中 \(id_u\)\(id_u + size_u - 1\) 的区间总和即可。

int query(int k, int l, int r)
{
	if (tree[k].l >= l && tree[k].r <= r)
		return tree[k].sum;
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2, sum = 0;
	if (l <= mid)
		sum = (sum + query(2 * k, l, r)) % p;
	if (r > mid)
		sum = (sum + query(2 * k + 1, l, r)) % p;
	return sum % p;
}

scanf("%d", &x);
printf("%d\n", query(1, id[x], id[x] + size[x] - 1) % p);

每次查询时间复杂度 \(O(logn)\)

模板

例题链接

#include <cstdio>
#include <algorithm> 
using namespace std;

const int maxn = 1e5 + 5;
const int maxm = 2e5 + 5;

struct Edge
{
	int to, nxt;
} edge[maxm];

struct node
{
	int l, r, lazy, sum;
} tree[5 * maxn];

int n, m, r, p, cnt, tot;
int head[maxn], id[maxn], rk[maxn], son[maxn], top[maxn];
int w[maxn], dep[maxn], fa[maxn], size[maxn];

void add_edge(int u, int v)
{
	cnt++;
	edge[cnt].to = v;
	edge[cnt].nxt = head[u];
	head[u] = cnt;
}

void dfs1(int u, int f)
{
	fa[u] = f;
	dep[u] = dep[f] + 1;
	size[u] = 1;
	for (int i = head[u]; i; i = edge[i].nxt)
	{
		if (edge[i].to != f)
		{
			dfs1(edge[i].to, u);
			size[u] += size[edge[i].to];
			if (size[edge[i].to] > size[son[u]])
				son[u] = edge[i].to;
		}
	}
}

void dfs2(int u, int t)
{
	id[u] = ++tot;
	rk[tot] = u;
	top[u] = t;
	if (son[u])
		dfs2(son[u], t);
	for (int i = head[u]; i; i = edge[i].nxt)
		if (edge[i].to != fa[u] && edge[i].to != son[u])
			dfs2(edge[i].to, edge[i].to);
}

void push_up(int k)
{
	tree[k].sum = (tree[2 * k].sum + tree[2 * k + 1].sum) % p;
}

void push_down(int k)
{
	if (tree[k].l == tree[k].r)
	{
		tree[k].lazy = 0;
		return;
	}
	tree[2 * k].sum = (tree[2 * k].sum + (tree[2 * k].r - tree[2 * k].l + 1) * tree[k].lazy) % p;
	tree[2 * k + 1].sum = (tree[2 * k + 1].sum + (tree[2 * k + 1].r - tree[2 * k + 1].l + 1) * tree[k].lazy) % p;
	tree[2 * k].lazy += tree[k].lazy;
	tree[2 * k + 1].lazy += tree[k].lazy;
	tree[k].lazy = 0;
}

void build(int k, int l, int r)
{
	tree[k].l = l;
	tree[k].r = r;
	if (l == r)
	{
		tree[k].sum = w[rk[l]];
		return;
	}
	int mid = (l + r) / 2;
	build(2 * k, l, mid);
	build(2 * k + 1, mid + 1, r);
	push_up(k); 
}

void update(int k, int l, int r, int x)
{
	if (tree[k].l >= l && tree[k].r <= r)
	{
		tree[k].sum = (tree[k].sum + (tree[k].r - tree[k].l + 1) * x) % p;
		tree[k].lazy += x;
		return;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	if (l <= mid)
		update(2 * k, l, r, x);
	if (r > mid)
		update(2 * k + 1, l, r, x);
	push_up(k);
}

int query(int k, int l, int r)
{
	if (tree[k].l >= l && tree[k].r <= r)
		return tree[k].sum;
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2, sum = 0;
	if (l <= mid)
		sum = (sum + query(2 * k, l, r)) % p;
	if (r > mid)
		sum = (sum + query(2 * k + 1, l, r)) % p;
	return sum % p;
}

void add_path(int x, int y, int z)
{
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		update(1, id[top[x]], id[x], z);
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	update(1, id[x], id[y], z);
}

int query_path(int x, int y)
{
	int sum = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		sum = (sum + query(1, id[top[x]], id[x])) % p;
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	sum = (sum + query(1, id[x], id[y])) % p;
	return sum % p;
}

int main()
{
	int opt, x, y, z;
	scanf("%d%d%d%d", &n, &m, &r, &p);
	for (int i = 1; i <= n; i++)
		scanf("%d", &w[i]);
	for (int i = 1; i <= n - 1; i++)
	{
		scanf("%d%d", &x, &y);
		add_edge(x, y);
		add_edge(y, x);
	}
	dfs1(r, 0);
	dfs2(r, 0);
	build(1, 1, n);
	for (int i = 1; i <= m; i++)
	{
		scanf("%d", &opt);
		if (opt == 1)
		{
			scanf("%d%d%d", &x, &y, &z);
			add_path(x, y, z);
		}
		else if (opt == 2)
		{
			scanf("%d%d", &x, &y);
			printf("%d\n", query_path(x, y) % p);
		}
		else if (opt == 3)
		{
			scanf("%d%d", &x, &z);
			update(1, id[x], id[x] + size[x] - 1, z);
		}
		else
		{
			scanf("%d", &x);
			printf("%d\n", query(1, id[x], id[x] + size[x] - 1) % p);
		}
	}
	return 0;
}

例题选讲

\(LCA\)

给定一棵树,试求树中两结点 \(x, y\) 的最近公共祖先。

\(LCA\) 问题同样可以使用树剖实现,具体流程同普通树剖基本一致:令链首深度大的那个结点跳到其链首的父结点,直到两个结点在同一条重链内为止。此时深度较小的那个结点就是结点 \(x, y\) 的最近公共祖先。原理显然,不再赘述。

int lca(int x, int y)
{
	while (top[x] != top[y])
	{
		if (dep[top[x]] <= dep[top[y]])
			swap(x, y);
		x = fa[top[x]];
	}
	if (dep[x] >= dep[y])
		swap(x, y);
	return x;
}

边权树剖

例题链接

给定一棵有 \(n\) 个结点的树,请写出一份代码维护以下两种操作:

  1. 将第 \(i\) 条边的边权修改为 \(w\)

  2. 查询从 \(u\)\(v\) 的路径上最大的边权

对于第二种操作,若 \(u = v\),则结果为 \(0\)

结点总数 \(n \leq 10^5\),操作总数 \(m \leq 3 \times 10^5\)

这道题与树剖模板的唯一区别在于,传统树剖维护的是点权,而这种树剖维护的是边权。如果按照维护边权的方法来维护,代码的复杂程度和时间效率都会大幅度降低。因此,我们需要考虑把边权转化成点权。

这里利用一个小技巧:对于一个结点 \(i\) 来说,它的点权等于其上方的边的边权。即设 \(i\) 的父结点为 \(u\),则 \(i\) 的点权等于 \(w_{u, i}\)。特殊地,根结点没有点权。

这样,我们就可以通过传统的树剖方法来维护边权。但是,修改和查询操作需要相应地做出改变。查询 \(x\)\(y\) 的路径中最大边权时,不可以将 \(x\)\(y\) 的最近公共祖先也一起比较,这样会导致最近公共祖先上方的边权也考虑进来。

至于在代码中如何避免查询到最近公共祖先,这里利用一个树剖的性质:当 \(x\)\(y\) 处于同一条重链时,深度较小的那一个结点就是查询的原始结点 \(x\)\(y\) 的最近公共祖先。设 \(x\) 是深度较小的结点,则我们查询的区间可以从 \(x\) 的后一个结点开始,这样就不会查询到 \(lca\)

#include <cstdio>
#include <algorithm>
using namespace std;

const int maxn = 1e5 + 5;
const int maxm = 2e5 + 5;

struct Edge
{
	int from, to, nxt, w, id;
} edge[maxm];

struct node
{
	int l, r, val;
} tree[5 * maxn];

int n, cnt;
int head[maxn], pos[maxn], rk[maxn], son[maxn];
int w[maxn], size[maxn], dep[maxn], fa[maxn], top[maxn];
char opt[10];

void add_edge(int u, int v, int w)
{
	cnt++;
	edge[cnt].from = u;
	edge[cnt].to = v;
	edge[cnt].w = w;
	edge[cnt].nxt = head[u];
	head[u] = cnt;
}

void dfs(int u, int f)
{
	for (int i = head[u]; i; i = edge[i].nxt)
	{
		if (edge[i].to != f)
		{
			w[edge[i].to] = edge[i].w;
			dfs(edge[i].to, u);
		}
	}
}

void dfs1(int u, int f)
{
	fa[u] = f;
	dep[u] = dep[f] + 1;
	size[u] = 1;
	for (int i = head[u]; i; i = edge[i].nxt)
	{
		if (edge[i].to != f)
		{
			dfs1(edge[i].to, u);
			size[u] += size[edge[i].to];
			if (size[edge[i].to] > size[son[u]])
				son[u] = edge[i].to;
		}
	}
}

void dfs2(int u, int t)
{
	cnt++;
	pos[u] = cnt;
	rk[cnt] = u;
	top[u] = t;
	if (son[u])
		dfs2(son[u], t);
	for (int i = head[u]; i; i = edge[i].nxt)
		if (edge[i].to != fa[u] && edge[i].to != son[u])
			dfs2(edge[i].to, edge[i].to);
}

void push_up(int k)
{
	tree[k].val = max(tree[2 * k].val, tree[2 * k + 1].val);
}

void build(int k, int l, int r)
{
	tree[k].l = l;
	tree[k].r = r;
	if (l == r)
	{
		tree[k].val = w[rk[l]];
		return;
	}
	int mid = (l + r) / 2;
	build(2 * k, l, mid);
	build(2 * k + 1, mid + 1, r);
	push_up(k);
}

void update(int k, int x, int w)
{
	if (tree[k].l == tree[k].r)
	{
		tree[k].val = w;
		return;
	}
	int mid = (tree[k].l + tree[k].r) / 2;
	if (x <= mid)
		update(2 * k, x, w);
	else
		update(2 * k + 1, x, w);
	push_up(k);
}

int query(int k, int l, int r)
{
	if (tree[k].l >= l && tree[k].r <= r)
		return tree[k].val;
	int mid = (tree[k].l + tree[k].r) / 2, val = 0;
	if (l <= mid)
		val = max(val, query(2 * k, l, r));
	if (r > mid)
		val = max(val, query(2 * k + 1, l, r));
	return val;
}

int query_path(int x, int y)
{
	int val = 0;
	while (top[x] != top[y])
	{
		if (dep[top[x]] < dep[top[y]])
			swap(x, y);
		val = max(val, query(1, pos[top[x]], pos[x]));
		x = fa[top[x]];
	}
	if (dep[x] > dep[y])
		swap(x, y);
	val = max(val, query(1, pos[x] + 1, pos[y]));
	return val;
}

int main()
{
	int u, v, x;
	scanf("%d", &n);
	for (int i = 1; i <= n - 1; i++)
	{
		scanf("%d%d%d", &u, &v, &x);
		add_edge(u, v, x);
		add_edge(v, u, x);
	}
	cnt = 0;
	dfs(1, 0);
	dfs1(1, 0);
	dfs2(1, 0);
	build(1, 1, n);
	while (scanf("%s", opt))
	{
		if (opt[0] == 'C')
		{
			scanf("%d%d", &u, &x);
			if (dep[edge[2 * u].from] > dep[edge[2 * u].to])
				update(1, pos[edge[2 * u].from], x);
			else
				update(1, pos[edge[2 * u].to], x);
		}
		else if (opt[0] == 'Q')
		{
			scanf("%d%d", &u, &v);
			if (u == v)
				printf("%d\n", 0);
			else
				printf("%d\n", query_path(u, v));
		}
		else
			break;
	}
	return 0;
}
posted @ 2021-07-24 23:22  kymru  阅读(104)  评论(0编辑  收藏  举报