浅谈树链剖分

树链剖分#

定理#

  • 重儿子:一个节点所有儿子中,子树大小最大的儿子即为重儿子,如有多个,任取一个即可。
  • 轻儿子:除了重儿子外的所有儿子。
  • 重边:父节点 重儿子的边。
  • 重链:由重边构成的极大链。
    如以下图。

过程#

dfs 序:优先遍历重儿子,这样就可以保证重链上所有点的编号连续。

如下图,蓝色数字即为求完 dfs 序后所有点的编号。

求完 dfs 序即将树转化成序列。

定理:树中任意一条路径均可拆分成小于等于 logn 条重链,即可拆分成小于等于 logn 连续区间。

将一条路径拆分成若干条条重链#

这个过程类似于倍增求 lca

假设求 x,y 的若干条重链。

如果 fx>fy 则先将 x 跳到该节点所在重链的顶部再走到他的父节点上。

如果 fy>fx 则先将 y 跳到该节点所在重链的顶部在走到他的父节点上。

其中 fi 表示节点 i 所在重链顶端的深度,即该节点在树的第几层。

最后一定会走到同一条重链上。

以上操作可以用线段树/分块/Splay 来维护。

例题#

Preface#

题目传送门

Solution#

  • 操作 12:即用前述的树链剖分的思想。
  • 操作 34:即把 dfs 序的一段连续区间求和或修改。

维护就与此题类似。

Code#

#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <cmath>
#include <sstream>
#include <set>
#include <unordered_set>
#include <map>
#include <unordered_map>

#define x first
#define y second
#define IOS ios::sync_with_stdio(false)
#define cit cin.tie(0)
#define cot cout.tie(0)

using namespace std;

typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;

const int N = 100010, M = 200010, MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const LL LLINF = 0x3f3f3f3f3f3f3f3f;
const double eps = 1e-8;

int n, m, root, mod;
int w[N], h[N], e[M], ne[M], idx;
int id[N], nw[N], cnt;
int dep[N], sz[N], top[N], fa[N], son[N];
struct Node
{
	int l, r;
	LL add, sum;
}tr[N * 4];

void add(int a, int b)
{
	e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void dfs1(int u, int father, int depth)
{
	dep[u] = depth, fa[u] = father, sz[u] = 1;
	for (int i = h[u]; ~i; i = ne[i])
	{
		int j = e[i];
		if (j == father) continue;
		dfs1(j, u, depth + 1);
		sz[u] += sz[j];
		if (sz[son[u]] < sz[j]) son[u] = j;
	}
}

void dfs2(int u, int t)
{
	id[u] = ++ cnt, nw[cnt] = w[u], top[u] = t;
	if (!son[u]) return;
	dfs2(son[u], t);
	for (int i = h[u]; ~i; i = ne[i])
	{
		int j = e[i];
		if (j == fa[u] || j == son[u]) continue;
		dfs2(j, j);
	}
}

void pushup(int u)
{
	tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % mod;
}

void pushdown(int u)
{
	auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
	left.add = (left.add + root.add) % mod, left.sum = (left.sum + (left.r - left.l + 1ll) * root.add) % mod;
	right.add = (right.add + root.add) % mod, right.sum = (right.sum + (right.r - right.l + 1ll) * root.add) % mod;
	root.add = 0;
}

void build(int u, int l, int r)
{
	if (l == r) tr[u] = {l, r, 0, nw[r]};
	else
	{
		tr[u] = {l, r};
		int mid = l + r >> 1;
		build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
		pushup(u);
	}
}

void modify(int u, int l, int r, int d)
{
	if (tr[u].l >= l && tr[u].r <= r)
	{
		tr[u].add = (tr[u].add + d) % mod;
		tr[u].sum = (tr[u].sum + (tr[u].r - tr[u].l + 1ll) * d) % mod;
	}
	else
	{
		pushdown(u);
		int mid = tr[u].l + tr[u].r >> 1;
		if (l <= mid) modify(u << 1, l, r, d);
		if (r > mid) modify(u << 1 | 1, l, r, d);
		pushup(u);
	}
}

LL query(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
	
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	LL res = 0;
	if (l <= mid) res = (res + query(u << 1, l, r)) % mod;
	if (r > mid) res = (res + query(u << 1 | 1, l, r)) % mod;
	return res;
}

void updata1(int u, int v, int k)
{
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		modify(1, id[top[u]], id[u], k);
		u = fa[top[u]];
	}
	if (dep[u] < dep[v]) swap(u, v);
	modify(1, id[v], id[u], k);
}

LL query1(int u, int v)
{
	LL res = 0;
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		res = (res + query(1, id[top[u]], id[u])) % mod;
		u = fa[top[u]];
	}
	if (dep[u] < dep[v]) swap(u, v);
	res = (res + query(1, id[v], id[u])) % mod;
	return res;
}

void updata2(int u, int k)
{
	modify(1, id[u], id[u] + sz[u] - 1, k);
}

LL query2(int u)
{
	return query(1, id[u], id[u] + sz[u] - 1);
}

void solve()
{
	cin >> n >> m >> root >> mod;
	for (int i = 1; i <= n; i ++ ) cin >> w[i];
	
	memset(h, -1, sizeof h);
	for (int i = 1; i < n; i ++ )
	{
		int a, b;
		cin >> a >> b;
		add(a, b), add(b, a);
	}
	
	dfs1(root, -1, 1);
	dfs2(root, root);
	build(1, 1, n);

	while (m -- )
	{
		int op, u, v, k;
		cin >> op >> u;
		if (op == 1)
		{
			cin >> v >> k;
			updata1(u, v, k);
		}
		else if (op == 2)
		{
			cin >> v;
			cout << query1(u, v) << endl;
		}
		else if (op == 3)
		{
			cin >> k;
			updata2(u, k);
		}
		else cout << query2(u) << endl;
	}
}

int main()
{
	IOS;
	cit, cot;
	int T = 1;
//	cin >> T;
	while (T -- ) solve();
	return 0;
}
posted @   hcywoi  阅读(25)  评论(1编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示
主题色彩