数据结构专题-专项训练:树链剖分

1. 前言

本篇博文为树链剖分的算法总结与专题训练。

没有学过树链剖分?

传送门:数据结构专题-学习笔记:树链剖分

树剖作为一种工具,可以有效解决各类树上问题。

需要注意的是,借助数据结构维护重链信息的时候,不一定只是使用线段树,平衡树,分块等等都可以使用。

当然这篇博文都是线段树。

在往下看之前,请先确保学习过可持久化线段树/动态开点线段树。

2. 题单

P3313 [SDOI2014]旅行

树剖板子题。

考虑对这棵树树剖之后,使用线段树来维护和与最大值的信息,但是怎么维护呢?

\(v\) 棵线段树呗!然后在每一棵线段树中维护对应的值,需要的时候就在对应线段树中修改查询。

结果空间复杂度为 \(O(4 \times n^2)\),只听到一声惨叫:“我 MLE 了!”

于是我们需要加一点优化。

学过可持久化线段树的读者就会发现这其实可以使用可持久化线段树。

我们初始时开 \(v\) 棵线段树,但是每棵线段树都只有一个根节点,记做 \(root_i\)

然后需要的时候静态开点即可。

空间复杂度:每次只会增加 \(\log n\) 个节点,则为可持久化线段树的空间复杂度,\(O(20n)\)

代码:

/*
========= Plozia =========
	Author:Plozia
	Problem:P3313 [SDOI2014]旅行
	Date:2021/3/9
========= Plozia =========
*/

#include <bits/stdc++.h>
#define Max(a, b) ((a > b) ? a : b)
using std::vector;
using std::string;

typedef long long LL;
const int MAXN = 1e5 + 10, MAXX = 2e6 + 10;
int n, q, wfir[MAXN], cfir[MAXN], w[MAXN], c[MAXN], root[MAXN], cnt;
int Son[MAXN], Size[MAXN], dep[MAXN], fa[MAXN], Top[MAXN], id[MAXN];
vector <int> Next[MAXN];
struct node
{
	int w, l, r, maxn;
}tree[MAXN + MAXX + 10];

int read()
{
	int sum = 0, fh = 1; char ch = getchar();
	for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
	for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
	return (fh == 1) ? sum : -sum;
}

namespace Segment_tree
{
	void change(int p, int l, int r, int k, int v, int w)
	{
		if (l == r) {tree[p].w = tree[p].maxn = w; return ;}
		int mid = (l + r) >> 1;
		if (k <= mid)
		{
			if (!tree[p].l) tree[p].l = ++cnt;
			change(tree[p].l, l, mid, k, v, w);
		}
		else
		{
			if (!tree[p].r) tree[p].r = ++cnt;
			change(tree[p].r, mid + 1, r, k, v, w);
		}
		tree[p].w = tree[tree[p].l].w + tree[tree[p].r].w;
		tree[p].maxn = Max(tree[tree[p].l].maxn, tree[tree[p].r].maxn);
	}
	
	int ask_sum(int p, int l1, int r1, int l2, int r2)
	{
		if (l1 >= l2 && r1 <= r2) return tree[p].w;
		int mid = (l1 + r1) >> 1, ans = 0;
		if (l2 <= mid)
		{
			if (tree[p].l) ans += ask_sum(tree[p].l, l1, mid, l2, r2);
		}
		if (r2 > mid)
		{
			if (tree[p].r) ans += ask_sum(tree[p].r, mid + 1, r1, l2, r2);
		}
		return ans;
	}
	
	int ask_max(int p, int l1, int r1, int l2, int r2)
	{
		if (l1 >= l2 && r1 <= r2) return tree[p].maxn;
		int mid = (l1 + r1) >> 1, ans = 0;
		if (l2 <= mid)
		{
			if (tree[p].l) ans = std::max(ans, ask_max(tree[p].l, l1, mid, l2, r2));
		}
		if (r2 > mid)
		{
			if (tree[p].r) ans = std::max(ans, ask_max(tree[p].r, mid + 1, r1, l2, r2));
		}
		return ans;
	}
}

void dfs1(int now, int father, int depth)
{
	dep[now] = depth; fa[now] = father; Size[now] = 1;
	for (int i = 0; i < Next[now].size(); ++i)
	{
		int u = Next[now][i];
		if (u == father) continue;
		dfs1(u, now, depth + 1);
		Size[now] += Size[u];
		if (Size[u] > Size[Son[now]]) Son[now] = u;
	}
}

void dfs2(int now, int top_father)
{
	id[now] = ++cnt; Top[now] = top_father;
	w[cnt] = wfir[now]; c[cnt] = cfir[now];
	if (!Son[now]) return ;
	dfs2(Son[now], top_father);
	for (int i = 0; i < Next[now].size(); ++i)
	{
		int u = Next[now][i];
		if (u == fa[now] || u == Son[now]) continue;
		dfs2(u, u);
	}
}

int main()
{
	n = read(), q = read();
	for (int i = 1; i <= n; ++i) wfir[i] = read(), cfir[i] = read();
	for (int i = 1; i < n; ++i)
	{
		int x = read(), y = read();
		Next[x].push_back(y), Next[y].push_back(x);
	}
	dfs1(1, 1, 1); dfs2(1, 1);
	for (int i = 1; i <= n; ++i) root[i] = i; cnt = n;
	for (int i = 1; i <= n; ++i) Segment_tree::change(root[c[i]], 1, n, i, c[i], w[i]);
	for (int i = 1; i <= q; ++i)
	{
		string str; std::cin >> str;
		if (str == "CC")
		{
			int x = read(), c_ = read();
			Segment_tree::change(root[c[id[x]]], 1, n, id[x], c[id[x]], 0);
			c[id[x]] = c_;
			Segment_tree::change(root[c[id[x]]], 1, n, id[x], c[id[x]], w[id[x]]);
		}
		if (str == "CW")
			int x = read(), w_ = read(); w[id[x]] = w_;
			Segment_tree::change(root[c[id[x]]], 1, n, id[x], c[id[x]], w[id[x]]);
		}
		if (str == "QS")
		{
			int x = read(), y = read(); int ans = 0, c_ = c[id[x]];
			while (Top[x] != Top[y])
			{
				if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
				ans += Segment_tree::ask_sum(root[c_], 1, n, id[Top[x]], id[x]);
				x = fa[Top[x]];
			}
			if (dep[x] > dep[y]) std::swap(x, y);
			ans += Segment_tree::ask_sum(root[c_], 1, n, id[x], id[y]);
			printf("%d\n", ans);
		}
		if (str == "QM")
		{
			int x = read(), y = read(); int ans = 0, c_ = c[id[x]];
			while (Top[x] != Top[y])
			{
				if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
				ans = std::max(ans, Segment_tree::ask_max(root[c_], 1, n, id[Top[x]], id[x]));
				x = fa[Top[x]];
			}
			if (dep[x] > dep[y]) std::swap(x, y);
			ans = std::max(ans, Segment_tree::ask_max(root[c_], 1, n, id[x], id[y]));
			printf("%d\n", ans);
		}
	}
	return 0;
}

P2486 [SDOI2011]染色

这道题是一道细节题。

解题思路还是比较明显的,使用线段树维护一下区间的头元素,尾元素,答案,合并的时候注意一下头尾元素的合并即可。

然后树剖呢?直接剖一下,然后跳就可以了呀,注意临界点的答案合并。

然后开始愉快的码码码,然后……调了 3 个小时。

所以这道题细节到底在哪里呢?

  1. 注意线段树 \(update\) 的时候左儿子与右儿子可能会首尾相同。
  2. 如果你只是写了一个 \(ask\) 函数而且这个函数只返回了区间的答案,那么请注意:
    在树剖的时候当我们完成询问 \([id_{Top_x},id_x]\) 的时候,一定要知道 \(id_{Top_x}\)\(id_{fa_{Top_x}}\) 的颜色是否相同,因为这涉及到答案是否要减一。相同则需要减一,防止后面的询问对这次产生干扰。

代码:

/*
========= Plozia =========
	Author:Plozia
	Problem:P2486 [SDOI2011]染色
	Date:2021/3/9
========= Plozia =========
*/

#include <bits/stdc++.h>
using std::vector;

typedef long long LL;
const int MAXN = 1e5 + 10;
int n, m, wfir[MAXN];
int Size[MAXN], Son[MAXN], dep[MAXN], fa[MAXN], Top[MAXN], id[MAXN], w[MAXN], cnt;
vector <int> Next[MAXN];
struct node
{
	int firnum, lasnum;
	int sum, l, r, add;
	#define l(p) tree[p].l
	#define r(p) tree[p].r
	#define s(p) tree[p].sum
	#define a(p) tree[p].add
	#define fir(p) tree[p].firnum
	#define las(p) tree[p].lasnum
}tree[MAXN << 2];

int read()
{
	int sum = 0, fh = 1; char ch = getchar();
	for (; ch < '0' || ch > '9'; ch = getchar()) fh -=  (ch == '-') << 1;
	for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
	return (fh == 1) ? sum : -sum;
}

namespace Segment_tree
{
	void build(int p, int l, int r)
	{
		l(p) = l, r(p) = r;
		if (l == r) {s(p) = 1; fir(p) = las(p) = w[l]; return ;}
		int mid = (l + r) >> 1;
		build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
		s(p) = s(p << 1) + s(p << 1 | 1);
		if (las(p << 1) == fir(p << 1 | 1)) --s(p);
		fir(p) = fir(p << 1), las(p) = las(p << 1 | 1);
	}
	
	void spread(int p)
	{
		if (a(p))
		{
			s(p << 1) = s(p << 1 | 1) = 1; 
			a(p << 1) = a(p << 1 | 1) = a(p);
			fir(p << 1) = fir(p << 1 | 1) = las(p << 1) = las(p << 1 | 1) = a(p);
			a(p) = 0;
		}
	}
	
	void change(int p, int l, int r, int c)
	{
		if (l(p) >= l && r(p) <= r) {s(p) = 1; a(p) = fir(p) = las(p) = c; return ;}
		spread(p);
		int mid = (l(p) + r(p)) >> 1;
		if (l <= mid) change(p << 1, l, r, c);
		if (r > mid) change(p << 1 | 1, l, r, c);
		s(p) = s(p << 1) + s(p << 1 | 1);
		if (las(p << 1) == fir(p << 1 | 1)) --s(p);
		fir(p) = fir(p << 1), las(p) = las(p << 1 | 1);
	}
	
	int ask(int p, int l, int r)
	{
		if (l(p) >= l && r(p) <= r) return s(p);
		spread(p);
		int mid = (l(p) + r(p)) >> 1, ans = 0;
		if (l <= mid && r > mid)
		{
			ans = ask(p << 1, l, r) + ask(p << 1 | 1, l, r);
			if (las(p << 1) == fir(p << 1 | 1)) --ans;
		}
		else if (l <= mid) ans = ask(p << 1, l, r);
		else if (r > mid) ans = ask(p << 1 | 1, l, r);
		return ans;
	}
	int ask2(int p, int k)
	{
		if (l(p) == r(p) && r(p) == k) return fir(p);
		spread(p);
		int mid = (l(p) + r(p)) >> 1;
		if (k <= mid) return ask2(p << 1, k);
		else return ask2(p << 1 | 1, k);
	}
}

void dfs1(int now, int father, int depth)
{
	dep[now] = depth; fa[now] = father; Size[now] = 1;
	for (int i = 0; i < Next[now].size(); ++i)
	{
		int u = Next[now][i];
		if (u == father) continue;
		dfs1(u, now, depth + 1);
		Size[now] += Size[u];
		if (Size[u] > Size[Son[now]]) Son[now] = u; 
	}
}

void dfs2(int now, int top_father)
{
	id[now] = ++cnt; Top[now] = top_father; w[cnt] = wfir[now];
	if (!Son[now]) return ;
	dfs2(Son[now], top_father);
	for (int i = 0; i < Next[now].size(); ++i)
	{
		int u = Next[now][i];
		if (u == fa[now] || u == Son[now]) continue;
		dfs2(u, u);
	}
}

int main()
{
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) wfir[i] = read();
	for (int i = 1; i < n; ++i)
	{
		int x = read(), y = read();
		Next[x].push_back(y), Next[y].push_back(x);
	}
	dfs1(1, 1, 1); dfs2(1, 1); Segment_tree::build(1, 1, n);
	for (int i = 1; i <= m; ++i)
	{
		char ch; std::cin >> ch;
		if (ch == 'C')
		{
			int a = read(), b = read(), c = read();
			while (Top[a] != Top[b])
			{
				if (dep[Top[a]] < dep[Top[b]]) std::swap(a, b);
				Segment_tree::change(1, id[Top[a]], id[a], c);
				a = fa[Top[a]];
			}
			if (dep[a] > dep[b]) std::swap(a, b);
			Segment_tree::change(1, id[a], id[b], c);
		}
		if (ch == 'Q')
		{
			int a = read(), b = read(), ans = 0;
			while (Top[a] != Top[b])
			{
				if (dep[Top[a]] < dep[Top[b]]) std::swap(a, b);
				ans += Segment_tree::ask(1, id[Top[a]], id[a]);
				if (Segment_tree::ask2(1, id[Top[a]]) == Segment_tree::ask2(1, id[fa[Top[a]]])) --ans;
				a = fa[Top[a]];
			}
			if (dep[a] > dep[b]) std::swap(a, b);
			ans += Segment_tree::ask(1, id[a], id[b]);
			printf("%d\n", ans);
		}
	}
	return 0;
}

P1505 [国家集训队]旅游

也是一道树剖题。

这道题首先需要『边权转点权』。

边权转点权的方式如下:

对于第 \(i\) 条边 \((u,v,w)\),我们取深度较大的这个点,假设为 \(x\),则 \(x\) 的点权为 \(w\)

这样,除根节点之外,每一个点都均匀有一个点权。

然后就可以愉快的树剖啦!

线段树需要注意的是,一个区间取反两次就是没有取反,代码中我采用的是异或的性质来处理。

还有,当 \(x,y\) 跳完,在同一条重链的时候,需要特判一下 \(x\) 是否等于 \(y\),因为 \(x,y\) 中深度较小的点记录的点权是不算在路径上的(为其与父节点的路径长度),需要过滤。

细节还是很多的,代码量也很大。

代码:

/*
========= Plozia =========
	Author:Plozia
	Problem:P1505 [国家集训队]旅游
	Date:2021/3/12
========= Plozia =========
*/

#include <bits/stdc++.h>
#define Max(a, b) ((a > b) ? a : b)
#define Min(a, b) ((a < b) ? a : b)
using std::vector;
using std::string;

typedef long long LL;
const int MAXN = 2e5 + 10;
int n, m, afir[MAXN];
int cnt, Top[MAXN], id[MAXN], fa[MAXN], dep[MAXN], Size[MAXN], Son[MAXN], a[MAXN];
struct Edge
{
	int x, y, z;
}e[MAXN];
struct node
{
	int l, r, sum, maxn, minn, add;
	#define l(p) tree[p].l
	#define r(p) tree[p].r
	#define s(p) tree[p].sum
	#define maxn(p) tree[p].maxn
	#define minn(p) tree[p].minn
	#define a(p) tree[p].add
}tree[MAXN << 2];
vector <int> Next[MAXN], Num[MAXN];

int read()
{
	int sum = 0, fh = 1; char ch = getchar();
	for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
	for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
	return (fh == 1) ? sum : -sum;
}

namespace Segment_tree
{
	void build(int p, int l, int r)
	{
		l(p) = l, r(p) = r;
		if (l == r) {s(p) = maxn(p) = minn(p) = a[l]; return ;}
		int mid = (l + r) >> 1;
		build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
		s(p) = s(p << 1) + s(p << 1 | 1);
		maxn(p) = Max(maxn(p << 1), maxn(p << 1 | 1));
		minn(p) = Min(minn(p << 1), minn(p << 1 | 1));
	}
	
	void spread(int p)
	{
		if (a(p) != 0)
		{
			a(p << 1) ^= 1; a(p << 1 | 1) ^= 1;
			s(p << 1) *= -1; s(p << 1 | 1) *= -1;
			int Maxn, Minn;
			Maxn = maxn(p << 1), Minn = minn(p << 1);
			minn(p << 1) = -Maxn, maxn(p << 1) = -Minn;
			Maxn = maxn(p << 1 | 1), Minn = minn(p << 1 | 1);
			minn(p << 1 | 1) = -Maxn, maxn(p << 1 | 1) = -Minn;
			a(p) = 0;
		}
	}
	
	void change_1(int p, int k, int w)
	{
		if (l(p) == r(p) && l(p) == k) {s(p) = maxn(p) = minn(p) = w; return ;}
		spread(p);
		int mid = (l(p) + r(p)) >> 1;
		if (k <= mid) change_1(p << 1, k, w);
		else change_1(p << 1 | 1, k, w);
		s(p) = s(p << 1) + s(p << 1 | 1);
		maxn(p) = Max(maxn(p << 1), maxn(p << 1 | 1));
		minn(p) = Min(minn(p << 1), minn(p << 1 | 1));
	}
	
	void change_2(int p, int l, int r)
	{
		if (l(p) >= l && r(p) <= r)
		{
			a(p) ^= 1; s(p) *= -1;
			int fir = maxn(p), sec = minn(p);
			minn(p) = fir * -1, maxn(p) = sec * -1; return ;
		}
		spread(p); int mid = (l(p) + r(p)) >> 1;
		if (l <= mid) change_2(p << 1, l, r);
		if (r > mid) change_2(p << 1 | 1, l, r);
		s(p) = s(p << 1) + s(p << 1 | 1);
		maxn(p) = Max(maxn(p << 1), maxn(p << 1 | 1));
		minn(p) = Min(minn(p << 1), minn(p << 1 | 1));
	}
	
	int ask_sum(int p, int l, int r)
	{
		if (l(p) >= l && r(p) <= r) return s(p);
		spread(p); int mid = (l(p) + r(p)) >> 1, ans = 0;
		if (l <= mid) ans += ask_sum(p << 1, l, r);
		if (r > mid) ans += ask_sum(p << 1 | 1, l, r);
		return ans;
	}
	
	int ask_maxn(int p, int l, int r)
	{
		if (l(p) >= l && r(p) <= r) return maxn(p);
		spread(p); int mid = (l(p) + r(p)) >> 1, ans = -0x7f7f7f7f, tmp = -0x7f7f7f7f;
		if (l <= mid) {tmp = ask_maxn(p << 1, l, r); ans = Max(ans, tmp);}
		if (r > mid) {tmp = ask_maxn(p << 1 | 1, l, r); ans = Max(ans, tmp);}
		return ans;
	}
	
	int ask_minn(int p, int l, int r)
	{
		if (l(p) >= l && r(p) <= r) return minn(p);
		spread(p); int mid = (l(p) + r(p)) >> 1, ans = 0x7f7f7f7f, tmp = 0x7f7f7f7f;
		if (l <= mid) {tmp = ask_minn(p << 1, l, r); ans = Min(ans, tmp);}
		if (r > mid) {tmp = ask_minn(p << 1 | 1, l, r); ans = Min(ans, tmp);}
		return ans;
	}
}

void dfs1(int now, int father, int depth)
{
	dep[now] = depth; fa[now] = father; Size[now] = 1;
	for (int i = 0; i < Next[now].size(); ++i)
	{
		int u = Next[now][i];
		if (u == father) continue;
		dfs1(u, now, depth + 1);
		Size[now] += Size[u];
		if (Size[u] > Size[Son[now]]) Son[now] = u;
	}
}

void dfs2(int now, int top_father)
{
	Top[now] = top_father; id[now] = ++cnt; a[cnt] = afir[now];
	if (!Son[now]) return ; dfs2(Son[now], top_father);
	for (int i = 0; i < Next[now].size(); ++i)
	{
		int u = Next[now][i];
		if (u == fa[now] || u == Son[now]) continue ;
		dfs2(u, u);
	}
}

int main()
{
	n = read();
	for (int i = 1; i < n; ++i)
	{
		int x = read() + 1, y = read() + 1, z = read();
		e[i] = (Edge){x, y, z};
		Next[x].push_back(y), Next[y].push_back(x);
		Num[x].push_back(z), Num[y].push_back(z);
	}
	dfs1(1, 1, 1);
	for (int i = 1; i < n; ++i)
	{
		int x = e[i].x, y = e[i].y, z = e[i].z;
		if (dep[x] > dep[y]) afir[x] = z;
		else afir[y] = z;
	}//边权转点权
	dfs2(1, 1); Segment_tree::build(1, 1, n);
	m = read();
	for (int i = 1; i <= m; ++i)
	{
		string str; std::cin >> str;
		if (str == "C")
		{
			int k = read(), w = read();
			Segment_tree::change_1(1, id[(dep[e[k].x] > dep[e[k].y]) ? e[k].x : e[k].y], w);
		}
		if (str == "N")
		{
			int x = read() + 1, y = read() + 1;
			while (Top[x] != Top[y])
			{
				if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
				Segment_tree::change_2(1, id[Top[x]], id[x]);
				x = fa[Top[x]];
			}
			if (dep[x] > dep[y]) std::swap(x, y);
			if (x != y) Segment_tree::change_2(1, id[x] + 1, id[y]);
		}
		if (str == "SUM")
		{
			int x = read() + 1, y = read() + 1, ans = 0;
			while (Top[x] != Top[y])
			{
				if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
				ans += Segment_tree::ask_sum(1, id[Top[x]], id[x]);
				x = fa[Top[x]];
			}
			if (dep[x] > dep[y]) std::swap(x, y);
			if (x != y) ans += Segment_tree::ask_sum(1, id[x] + 1, id[y]);
			printf("%d\n", ans);
		}
		if (str == "MAX")
		{
			int x = read() + 1, y = read() + 1, ans = -0x7f7f7f7f, tmp = -0x7f7f7f7f;
			while (Top[x] != Top[y])
			{
				if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
				tmp = Segment_tree::ask_maxn(1, id[Top[x]], id[x]); ans = Max(ans, tmp);
				x = fa[Top[x]];
			}
			if (dep[x] > dep[y]) std::swap(x, y);
			if (x != y) {tmp = Segment_tree::ask_maxn(1, id[x] + 1, id[y]); ans = Max(ans, tmp);}
			printf("%d\n", (ans == -0x7f7f7f7f) ? 0 : ans);
		}
		if (str == "MIN")
		{
			int x = read() + 1, y = read() + 1, ans = 0x7f7f7f7f, tmp = 0x7f7f7f7f;
			while (Top[x] != Top[y])
			{
				if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
				tmp = Segment_tree::ask_minn(1, id[Top[x]], id[x]); ans = Min(ans, tmp);
				x = fa[Top[x]];
			}
			if (dep[x] > dep[y]) std::swap(x, y);
			if (x != y) {tmp = Segment_tree::ask_minn(1, id[x] + 1, id[y]); ans = Min(ans, tmp);}
			printf("%d\n", (ans == 0x7f7f7f7f) ? 0 : ans);
		}
	}
	return 0;
}

3. 总结

树剖的题目难点还是在维护重链信息上,树剖本身不是特别难,在面对树剖题目的时候我们可以假设问题是一个序列问题,在确定怎么维护之后套上树剖即可。

posted @ 2022-04-15 19:45  Plozia  阅读(39)  评论(0编辑  收藏  举报