dsu on tree学习笔记

持续更新 但愿不咕

即树上启发式合并 一般用来解决不带修的子树询问问题

具体看一道题吧 CF600E Lomsat gelral

首先重链剖分确实是一个很神奇的东西 我们的 \(dsu\) 也是基于重儿子来的
首先这题非常暴力的一个做法就是对于每个点暴力统计它的子树里的答案 然后清空
这样的复杂度是 \(O(n ^ 2)\)

我们考虑优化 进行 \(dfs\) 的时候 我们先走轻儿子 统计它的答案并清空
轻儿子都走完了再走重儿子 统计它的答案然后不清空
然后我们再走所有的轻儿子并且统计答案

这样的复杂度为什么是对的呢 其实可以用重链剖分来理解
需要用到那个我不会证的结论

对于一个点 它到根节点的路径上只会有不超过 \(logn\) 条重链

我们考虑一个点都什么时候会被统计答案
对于一个轻儿子 我们都需要走这条轻边 然后把它子树里的点都暴力一遍
那么实际上对于一个点而言 它被统计的次数就是它到根节点路径上轻边的数量
又因为连接重链的就是轻边 所以它到根节点的路径上只有不超过 \(logn\) 条轻边 这样它只会被统计不超过 \(logn\)
那我们有 \(n\) 个点 复杂度就是 \(O(nlogn)\)

code:

#include <bits/stdc++.h>
#define ll long long
using namespace std;

const int N = 1e5 + 0721;
int siz[N], son[N]; //求重儿子用
int head[N], nxt[N << 1], to[N << 1], _cnt; //链前
int cnt[N], maxn; //这题要用 
ll sum, ans[N]; 
int cl[N];
int n;

inline void add_edge(int x, int y) {
	to[++_cnt] = y;
	nxt[_cnt] = head[x];
	head[x] = _cnt;
}

void get_Gson(int x, int fa) { //求重儿子 
	siz[x] = 1;
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == fa) continue;
		get_Gson(y, x);
		if (siz[y] > siz[son[x]]) son[x] = y;
		siz[x] += siz[y];
	}
}

void init(int x, int fa) { //清空当前子树内的答案 
	--cnt[cl[x]];
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == fa) continue;
		init(y, x);
	}
}

void get_ans(int x, int fa, int p) { //遍历并暴力统计答案 
	++cnt[cl[x]];
	if (cnt[cl[x]] > maxn) {
		maxn = cnt[cl[x]];
		sum = cl[x];
	} else if (cnt[cl[x]] == maxn) sum += cl[x];
	
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == fa || y == p) continue; //重儿子已经统计过了 所以不走 
		get_ans(y, x, p); 
	}
}

void dfs(int x, int fa) {
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == fa || y == son[x]) continue; //先不走重儿子
		dfs(y, x);
		init(y, x);
		sum = maxn = 0; //儿子求完就清空 
	}
	if (son[x]) dfs(son[x], x); //走重儿子 并把答案保留
	get_ans(x, fa, son[x]); //走别的儿子 更新答案 
	ans[x] = sum;
}

int main() {
	scanf("%d", &n);
	for (int i = 1; i <= n; ++i) scanf("%d", &cl[i]);
	for (int i = 1; i < n; ++i) {
		int x, y;
		scanf("%d%d", &x, &y);
		add_edge(x, y);
		add_edge(y, x);
	}
	
	get_Gson(1, 0);
	dfs(1, 0);
	
	for (int i = 1; i <= n; ++i) printf("%lld ", ans[i]);
	
	return 0;
}
posted @ 2023-07-21 20:17  Steven24  阅读(11)  评论(0编辑  收藏  举报