树上启发式合并

Dsu on Tree

又名"树上启发式合并" 可以处理静态的树上子树答案统计问题

可以将暴力做法\(O(n^2)\)压到\(O(nlogn)\)

常数小于线段树合并 但是线段树合并支持历史版本 带修等多种操作

U41492 树上数颜色

模板题

我们在\(dfs\)的时候首先\(dfs\)轻节点 记录对应答案后清空 再\(dfs\)重儿子 不清空

统计当前\(u\)节点的答案时 暴力累加轻儿子的贡献 再累加已经求出的重儿子的贡献即可

清空的时候采取暴力清空方式即可

我们的\(solve\)函数中基本操作是:

  1. 递归\(solve\)轻儿子 统计它们的答案之后清空
  2. 递归\(solve\)重儿子 统计答案后不清空
  3. 暴力统计所有轻儿子的贡献 继承重儿子的答案后记录答案
  4. 如果需要清空贡献 那么全部清空(包括当前节点,重儿子和轻儿子)
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define mid (l+r>>1)
#define eb emplace_back
#define print(x) cout << #x << '=' << x << endl
constexpr int N = 2e5 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get() 
int read ()
{
	int x = 0 , f = 1;
	char ch = getchar();
	while ( !isdigit ( ch ) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
	while ( isdigit ( ch ) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
	return x * f;
}

int n , m , col[N] , ans[N] , b[N] , res;

vector<int> e[N];
void add ( int u , int v ) { e[u].eb(v); }

void upd ( int x ) { res += ( ++ b[col[x]] == 1 ); }
void del ( int x ) { res -= ( -- b[col[x]] == 0 ); }

int sz[N] , l[N] , r[N] , rev[N] , timer , son[N];

void dfs ( int u , int f )
{
	l[u] = ++timer , rev[timer] = u , sz[u] = 1;
	for ( auto v : e[u] )
	{
		if ( v == f ) continue; 
		dfs ( v , u );
		sz[u] += sz[v];
		if ( sz[son[u]] < sz[v] ) son[u] = v;
	}
	r[u] = timer;
}

void solve ( int u , int f , int keep )
{
	for ( auto v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
	if ( son[u] ) solve ( son[u] , u , 1 );
	for ( auto v : e[u] ) if ( v != f && v != son[u] ) for ( int i = l[v] ; i <= r[v] ; i ++ ) upd ( rev[i] );
	upd(u) , ans[u] = res;
	if ( !keep ) for ( int i = l[u] ; i <= r[u] ; i ++ ) del ( rev[i] );
}

signed main ()
{
	ios::sync_with_stdio(false);
	cin.tie(0) , cout.tie(0);
	n = read();
	for ( int i = 1 , u , v ; i < n ; i ++ ) u = read() , v = read() , add ( u , v ) , add ( v , u );
	for ( int i = 1 ; i <= n ; i ++ ) col[i] = read();
	dfs ( 1 , 0 ) , solve ( 1 , 0 , 0 );
	m = read();
	for ( int i = 1 ; i <= m ; i ++ ) cout << ans[read()] << endl;
	return 0;
}

Lomsat gelral

板子题

建议在统计贡献的时候用子树的\(dfn\)序来统计 避免了繁琐的递归判断过程(例如下面注释掉的错误\(solve\)函数)

必须注意:统计答案的时候 轻儿子中的重儿子的贡献也是要一并统计进去的

也就是不能简单地用\(son[u]\)来判断是否递归下去

#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define eb emplace_back
#define int long long
#define print(x) cerr << #x << '=' << x << endl
constexpr int N = 2e5 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get() 
int read ()
{
	int x = 0 , f = 1;
	char ch = getchar();
	while ( !isdigit ( ch ) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
	while ( isdigit ( ch ) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
	return x * f;
}

int n , m , col[N] , ans[N] , b[N];

int l[N] , r[N] , rev[N] , sz[N] , son[N] , maxx , sum , timer;

vector<int> e[N];
void add ( int u , int v ) { e[u].eb(v); }

void dfs ( int u , int f )
{
	sz[u] = 1 , l[u] = ++timer , rev[timer] = u;
	for ( auto v : e[u] )
	{
		if ( v == f ) continue; 
		dfs ( v , u );
		sz[u] += sz[v];
		if ( sz[son[u]] < sz[v] ) son[u] = v;
	}
	r[u] = timer;
}

void upd ( int u )
{
	b[col[u]] ++;
	if ( b[col[u]] > maxx ) sum = col[u] , maxx = b[col[u]];
	else if ( b[col[u]] == maxx ) sum += col[u];
}

void del ( int u ) { b[col[u]] --; }

//void solve ( int u , int f , int keep )
//{
//	for ( auto v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
//	if ( son[u] ) solve ( son[u] , u , 1 );
//	upd ( u , f ) , ans[u] = sum;
//	if ( !keep ) clear ( u , f ) , sum = 0 , maxx = 0;
//}

void solve ( int u , int f , int keep )
{
	for ( auto v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
	if ( son[u] ) solve ( son[u] , u , 1 );
	for ( auto v : e[u] ) if ( v != f && v != son[u] ) for ( int i = l[v] ; i <= r[v] ; i ++ ) upd(rev[i]);
	upd(u); ans[u] = sum;
	if ( !keep ) { for ( int i = l[u] ; i <= r[u] ; i ++ ) del(rev[i]); sum = 0 , maxx = 0; }
}

signed main ()
{
//	freopen ( "a.in" , "r" , stdin );
	ios::sync_with_stdio(false);
	cin.tie(0) , cout.tie(0);
	n = read();
	for ( int i = 1 ; i <= n ; i ++ ) col[i] = read();
	for ( int i = 1 , u , v ; i < n ; i ++ ) u = read() , v = read() , add ( u , v ) , add ( v , u );
	dfs ( 1 , 0 ) , solve ( 1 , 0 , 0 );
	for ( int i = 1 ; i <= n ; i ++ ) cout << ans[i] << ' ';
	return 0;
}

XOR Tree

运用了启发式合并的思想

可以发现如果维护一个从上到下的树上前缀和\(dis\) 那么存在异或和为\(0\)的一条路径\((u,v)\)当且仅当\(dis[u]\oplus dis[v]\oplus a[lca(u,v)]=0\)

那么我们可以用\(set\)维护每一个子树中的所有\(dis\)值 到子树的根节点\(u\)的时候 遍历所有的\(v\) 在集合\(v\)中找\(dis[u]\oplus a[lca(u,v)]\) 如果存在 那么这棵子树肯定需要清空

用启发式合并可以压到\(O(nlog^2n)\)

#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define eb emplace_back
#define int long long
#define print(x) cerr << #x << '=' << x << endl
constexpr int N = 2e5 + 5;
//char buf[1<<24] , *p1 , *p2;
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
#define getchar() cin.get() 
int read ()
{
	int x = 0 , f = 1;
	char ch = getchar();
	while ( !isdigit ( ch ) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
	while ( isdigit ( ch ) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
	return x * f;
}

set<int> s[N];
vector<int> e[N];
void add ( int u , int v ) { e[u].eb(v); }

int n , a[N] , dis[N] , ans;

void dfs ( int u , int f )
{
	int flag = 0;
	dis[u] = a[u] ^ dis[f] , s[u].insert(dis[u]);
	for ( auto v : e[u] )  
	{
		if ( v == f ) continue;
		dfs ( v , u );
		if ( s[u].size() < s[v].size() ) swap ( s[u] , s[v] );
		for ( auto x : s[v] ) if ( s[u].find(x^a[u]) != s[u].end() ) flag = 1;
		for ( auto x : s[v] ) s[u].insert(x);
	}
	if ( flag ) ans ++ , s[u].clear();
}
signed main ()
{
	ios::sync_with_stdio(false);
	cin.tie(0) , cout.tie(0);
	n = read();
	for ( int i = 1 ; i <= n ; i ++ ) a[i] = read();
	for ( int i = 1 , u , v ; i < n ; i ++ ) u = read() , v = read() , add ( u , v ) , add ( v , u );
	dfs ( 1 , 0 );
	cout << ans << endl;
	return 0;
}

Blood Cousins

\(Dsu\ on\ tree\) 还有线段树合并写法(在另一篇\(blog\)中)

还是要注意\(jump\)函数的写法罢(\(dep\) \(rev\) \(top\)这些数组要分清)

#include <bits/stdc++.h>
using namespace std;
#define mid (l+r>>1)
#define endl '\n'
#define inl inline
#define ls(p) t[p].son[0]
#define rs(p) t[p].son[1]
#define lson ls(p),l,mid
#define rson rs(p),mid+1,r
#define eb emplace_back
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get();
int read()
{
	int x = 0 , f = 1;
	char ch = getchar();
	while ( !isdigit(ch) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
	while ( isdigit(ch) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
	return x * f ;
}

int n , m , ans[N] , cnt[N] , l[N] , r[N];

int fa[N] , dep[N] , sz[N] , son[N];
int pos[N] , rev[N] , top[N] , timer;

struct que { int val , id; };
vector<que> q[N];

vector<int> e[N];
inl void add ( int u , int v ) { e[u].eb(v); }

struct LCA
{
	void dfs1 ( int u , int f )
	{
		dep[u] = dep[f] + 1 , fa[u] = f , sz[u] = 1;
		for ( auto v : e[u] )
			if ( v ^ f )
			{
				dfs1 ( v , u );
				sz[u] += sz[v];
				if ( sz[son[u]] < sz[v] ) son[u] = v;//存疑
			}
	}
	void dfs2 ( int u , int tp )
	{
		top[u] = tp , pos[u] = ++timer , rev[timer] = u;
		if ( son[u] ) dfs2 ( son[u] , tp );
		for ( auto v : e[u] ) if ( v != fa[u] && v != son[u] ) dfs2 ( v , v );
	}
	int jump ( int u , int k )
	{
		while ( k > 0 && u != 0 )
		{
			if ( pos[u] - pos[top[u]] + 1 > k ) return rev[pos[u]-k];
			k -= pos[u] - pos[top[u]] + 1;
			u = fa[top[u]];
		}
		return u;
	}
}L;

inl void add ( int x ) { cnt[dep[x]] ++; }
inl void del ( int x ) { cnt[dep[x]] --; } 

void dfs ( int u , int f , int keep )
{
	for ( auto v : e[u] ) if ( v != son[u] && v != f ) dfs ( v , u , 0 );
	if ( son[u] ) dfs ( son[u] , u , 1 );
	for ( auto v : e[u] ) if ( v != son[u] && v != f ) for ( int i = pos[v] ; i <= pos[v] + sz[v] - 1 ; i ++ ) add(rev[i]);
	add(u);
	for ( auto [val,id] : q[u] ) ans[id] = cnt[val];
	if ( !keep ) for ( int i = pos[u] ; i <= pos[u] + sz[u] - 1 ; i ++ ) del(rev[i]);
}

signed main ()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr) , cout.tie(nullptr);
	n = read();
	for ( int i = 1 , fa ; i <= n ; i ++ ) fa = read() , add ( fa , i ) , add ( i , fa );
	for ( auto v : e[0] ) L.dfs1 ( v , 0 );
	for ( auto v : e[0] ) L.dfs2 ( v , v );
	m = read();
	for ( int i = 1 , x , y ; i <= m ; i ++ )
	{
		x = read() , y = read();
		int lca = L.jump ( x , y );
		if ( lca ) q[lca].eb((que){dep[x],i});
		else ans[i] = 1;
	}
	for ( auto v : e[0] ) dfs ( v , 0 , 0 );
	for ( int i = 1 ; i <= m ; i ++ ) cout << ans[i] - 1 << ' ';
	return 0;
}

P5384 [Cnoi2019] 雪松果树

同上一道题 但是线段树合并会被卡 \(dsu\)可以过

#include <bits/stdc++.h>
using namespace std;
#define mid (l+r>>1)
#define endl '\n'
#define inl inline
#define ls(p) t[p].son[0]
#define rs(p) t[p].son[1]
#define lson ls(p),l,mid
#define rson rs(p),mid+1,r
#define eb emplace_back
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int N = 1e6 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get();
int read()
{
	int x = 0 , f = 1;
	char ch = getchar();
	while ( !isdigit(ch) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
	while ( isdigit(ch) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
	return x * f ;
}

int n , m , ans[N] , cnt[N] , l[N] , r[N];

int fa[N] , dep[N] , sz[N] , son[N];
int pos[N] , rev[N] , top[N] , timer;

struct que { int val , id; };
vector<que> q[N];

vector<int> e[N];
inl void add ( int u , int v ) { e[u].eb(v); }

struct LCA
{
	void dfs1 ( int u , int f )
	{
		dep[u] = dep[f] + 1 , fa[u] = f , sz[u] = 1;
		for ( auto v : e[u] )
			if ( v ^ f )
			{
				dfs1 ( v , u );
				sz[u] += sz[v];
				if ( sz[son[u]] < sz[v] ) son[u] = v;//存疑
			}
	}
	void dfs2 ( int u , int tp )
	{
		top[u] = tp , pos[u] = ++timer , rev[timer] = u;
		if ( son[u] ) dfs2 ( son[u] , tp );
		for ( auto v : e[u] ) if ( v != fa[u] && v != son[u] ) dfs2 ( v , v );
	}
	int jump ( int u , int k )
	{
		while ( k > 0 && u != 0 )
		{
			if ( pos[u] - pos[top[u]] + 1 > k ) return rev[pos[u]-k];
			k -= pos[u] - pos[top[u]] + 1;
			u = fa[top[u]];
		}
		return u;
	}
}L;

inl void add ( int x ) { cnt[dep[x]] ++; }
inl void del ( int x ) { cnt[dep[x]] --; } 

void dfs ( int u , int f , int keep )
{
	for ( auto v : e[u] ) if ( v != son[u] && v != f ) dfs ( v , u , 0 );
	if ( son[u] ) dfs ( son[u] , u , 1 );
	for ( auto v : e[u] ) if ( v != son[u] && v != f ) for ( int i = pos[v] ; i <= pos[v] + sz[v] - 1 ; i ++ ) add(rev[i]);
	add(u);
	for ( auto [val,id] : q[u] ) ans[id] = cnt[val];
	if ( !keep ) for ( int i = pos[u] ; i <= pos[u] + sz[u] - 1 ; i ++ ) del(rev[i]);
}

signed main ()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr) , cout.tie(nullptr);
	n = read();m = read();
	for ( int i = 2 , fa ; i <= n ; i ++ ) fa = read() , add ( fa , i ) , add ( i , fa );
	L.dfs1 ( 1 , 0 ) , L.dfs2 ( 1 , 1 );
	for ( int i = 1 , x , y ; i <= m ; i ++ )
	{
		x = read() , y = read();
		int lca = L.jump ( x , y );
		if ( lca ) q[lca].eb((que){dep[x],i});
		else ans[i] = 1;
	}
	dfs ( 1 , 0 , 0 );
	for ( int i = 1 ; i <= m ; i ++ ) cout << ans[i] - 1 << ' ';
	return 0;
}

Tree Requests

突然感觉\(dsu\)代码更短常数空间更小细节更少一点... 那以后就用\(dsu\)

线段树合并垃圾回收的时候不能按照正常思路写真的烦人()

对于每一个节点状压维护一个二进制数 表示这个点的字符个数是奇数/偶数个 统计答案的时候如果子树中的字符有两个或以上为奇数的 那么不可以 否则可以

#include <bits/stdc++.h>
using namespace std;
#define mid (l+r>>1)
#define endl '\n'
#define inl inline
#define ls(p) t[p].son[0]
#define rs(p) t[p].son[1]
#define lson ls(p),l,mid
#define rson rs(p),mid+1,r
#define eb emplace_back
const int inf = 0x3f3f3f3f;
const int N = 5e5 + 5;
//char buf[1<<24] , *p1 , *p2;
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
#define getchar() cin.get();
int read()
{
	int x = 0 , f = 1;
	char ch = getchar();
	while ( !isdigit(ch) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
	while ( isdigit(ch) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
	return x * f ;
}

int n , m , dep[N] , timer , rev[N] , son[N] , sz[N] , l[N] , r[N] , a[N];
int cnt[N] , ans[N];

struct que { int val , id; };
vector<que> q[N];

vector<int> e[N];
inl void add ( int u , int v ) { e[u].eb(v); }

void dfs ( int u , int f )
{
	dep[u] = dep[f] + 1 , sz[u] = 1 , l[u] = ++timer , rev[timer] = u;
	for ( int v : e[u] )
		if ( v != f )
		{
			dfs ( v , u );
			sz[u] += sz[v];
			if ( sz[son[u]] < sz[v] ) son[u] = v;
		}
	r[u] = timer;
}

void upd ( int u ) { cnt[dep[u]] ^= ( 1 << a[u] ); }
void update ( int u ) { for ( int i = l[u] ; i <= r[u] ; i ++ ) upd ( rev[i] ); }
	
void solve ( int u , int f , int keep )
{
	for ( int v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
	if ( son[u] ) solve ( son[u] , u , 1 );
	for ( int v : e[u] ) if ( v != f && v != son[u] ) update(v);
	upd(u);
	for ( auto [val,id] : q[u] )
	{
		int res = 0;
		for ( int i = 0 ; i < 26 ; i ++ ) res += ( ( cnt[val] & ( 1 << i ) ) == ( 1 << i ) );  
		ans[id] = res;
	}
	if ( !keep ) update(u);
}

string s;

signed main ()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr) , cout.tie(nullptr);
	n = read() , m = read();
	for ( int i = 2 , fa ; i <= n ; i ++ ) fa = read() , add ( i , fa ) , add ( fa , i );
	cin >> s; s = " " + s;
	for ( int i = 1 ; i <= n ; i ++ ) a[i] = s[i] - 'a';
	dfs ( 1 , 0 );
	for ( int i = 1 , x , y ; i <= m ; i ++ )
	{
		x = read() , y = read();
		q[x].eb((que){y,i});
	}
	solve ( 1 , 0 , 0 );
	for ( int i = 1 ; i <= m ; i ++ ) cout << ( ans[i] > 1 ? "No" : "Yes" ) << endl;
	return 0;
}
posted @ 2023-07-27 09:59  Echo_Long  阅读(19)  评论(0编辑  收藏  举报