树上启发式合并
Dsu on Tree
又名"树上启发式合并" 可以处理静态的树上子树答案统计问题
可以将暴力做法\(O(n^2)\)压到\(O(nlogn)\)
常数小于线段树合并 但是线段树合并支持历史版本 带修等多种操作
U41492 树上数颜色
模板题
我们在\(dfs\)的时候首先\(dfs\)轻节点 记录对应答案后清空 再\(dfs\)重儿子 不清空
统计当前\(u\)节点的答案时 暴力累加轻儿子的贡献 再累加已经求出的重儿子的贡献即可
清空的时候采取暴力清空方式即可
我们的\(solve\)函数中基本操作是:
- 递归\(solve\)轻儿子 统计它们的答案之后清空
- 递归\(solve\)重儿子 统计答案后不清空
- 暴力统计所有轻儿子的贡献 继承重儿子的答案后记录答案
- 如果需要清空贡献 那么全部清空(包括当前节点,重儿子和轻儿子)
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define mid (l+r>>1)
#define eb emplace_back
#define print(x) cout << #x << '=' << x << endl
constexpr int N = 2e5 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get()
int read ()
{
int x = 0 , f = 1;
char ch = getchar();
while ( !isdigit ( ch ) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
while ( isdigit ( ch ) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
return x * f;
}
int n , m , col[N] , ans[N] , b[N] , res;
vector<int> e[N];
void add ( int u , int v ) { e[u].eb(v); }
void upd ( int x ) { res += ( ++ b[col[x]] == 1 ); }
void del ( int x ) { res -= ( -- b[col[x]] == 0 ); }
int sz[N] , l[N] , r[N] , rev[N] , timer , son[N];
void dfs ( int u , int f )
{
l[u] = ++timer , rev[timer] = u , sz[u] = 1;
for ( auto v : e[u] )
{
if ( v == f ) continue;
dfs ( v , u );
sz[u] += sz[v];
if ( sz[son[u]] < sz[v] ) son[u] = v;
}
r[u] = timer;
}
void solve ( int u , int f , int keep )
{
for ( auto v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
if ( son[u] ) solve ( son[u] , u , 1 );
for ( auto v : e[u] ) if ( v != f && v != son[u] ) for ( int i = l[v] ; i <= r[v] ; i ++ ) upd ( rev[i] );
upd(u) , ans[u] = res;
if ( !keep ) for ( int i = l[u] ; i <= r[u] ; i ++ ) del ( rev[i] );
}
signed main ()
{
ios::sync_with_stdio(false);
cin.tie(0) , cout.tie(0);
n = read();
for ( int i = 1 , u , v ; i < n ; i ++ ) u = read() , v = read() , add ( u , v ) , add ( v , u );
for ( int i = 1 ; i <= n ; i ++ ) col[i] = read();
dfs ( 1 , 0 ) , solve ( 1 , 0 , 0 );
m = read();
for ( int i = 1 ; i <= m ; i ++ ) cout << ans[read()] << endl;
return 0;
}
Lomsat gelral
板子题
建议在统计贡献的时候用子树的\(dfn\)序来统计 避免了繁琐的递归判断过程(例如下面注释掉的错误\(solve\)函数)
必须注意:统计答案的时候 轻儿子中的重儿子的贡献也是要一并统计进去的
也就是不能简单地用\(son[u]\)来判断是否递归下去
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define eb emplace_back
#define int long long
#define print(x) cerr << #x << '=' << x << endl
constexpr int N = 2e5 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get()
int read ()
{
int x = 0 , f = 1;
char ch = getchar();
while ( !isdigit ( ch ) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
while ( isdigit ( ch ) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
return x * f;
}
int n , m , col[N] , ans[N] , b[N];
int l[N] , r[N] , rev[N] , sz[N] , son[N] , maxx , sum , timer;
vector<int> e[N];
void add ( int u , int v ) { e[u].eb(v); }
void dfs ( int u , int f )
{
sz[u] = 1 , l[u] = ++timer , rev[timer] = u;
for ( auto v : e[u] )
{
if ( v == f ) continue;
dfs ( v , u );
sz[u] += sz[v];
if ( sz[son[u]] < sz[v] ) son[u] = v;
}
r[u] = timer;
}
void upd ( int u )
{
b[col[u]] ++;
if ( b[col[u]] > maxx ) sum = col[u] , maxx = b[col[u]];
else if ( b[col[u]] == maxx ) sum += col[u];
}
void del ( int u ) { b[col[u]] --; }
//void solve ( int u , int f , int keep )
//{
// for ( auto v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
// if ( son[u] ) solve ( son[u] , u , 1 );
// upd ( u , f ) , ans[u] = sum;
// if ( !keep ) clear ( u , f ) , sum = 0 , maxx = 0;
//}
void solve ( int u , int f , int keep )
{
for ( auto v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
if ( son[u] ) solve ( son[u] , u , 1 );
for ( auto v : e[u] ) if ( v != f && v != son[u] ) for ( int i = l[v] ; i <= r[v] ; i ++ ) upd(rev[i]);
upd(u); ans[u] = sum;
if ( !keep ) { for ( int i = l[u] ; i <= r[u] ; i ++ ) del(rev[i]); sum = 0 , maxx = 0; }
}
signed main ()
{
// freopen ( "a.in" , "r" , stdin );
ios::sync_with_stdio(false);
cin.tie(0) , cout.tie(0);
n = read();
for ( int i = 1 ; i <= n ; i ++ ) col[i] = read();
for ( int i = 1 , u , v ; i < n ; i ++ ) u = read() , v = read() , add ( u , v ) , add ( v , u );
dfs ( 1 , 0 ) , solve ( 1 , 0 , 0 );
for ( int i = 1 ; i <= n ; i ++ ) cout << ans[i] << ' ';
return 0;
}
XOR Tree
运用了启发式合并的思想
可以发现如果维护一个从上到下的树上前缀和\(dis\) 那么存在异或和为\(0\)的一条路径\((u,v)\)当且仅当\(dis[u]\oplus dis[v]\oplus a[lca(u,v)]=0\)
那么我们可以用\(set\)维护每一个子树中的所有\(dis\)值 到子树的根节点\(u\)的时候 遍历所有的\(v\) 在集合\(v\)中找\(dis[u]\oplus a[lca(u,v)]\) 如果存在 那么这棵子树肯定需要清空
用启发式合并可以压到\(O(nlog^2n)\)
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define eb emplace_back
#define int long long
#define print(x) cerr << #x << '=' << x << endl
constexpr int N = 2e5 + 5;
//char buf[1<<24] , *p1 , *p2;
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
#define getchar() cin.get()
int read ()
{
int x = 0 , f = 1;
char ch = getchar();
while ( !isdigit ( ch ) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
while ( isdigit ( ch ) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
return x * f;
}
set<int> s[N];
vector<int> e[N];
void add ( int u , int v ) { e[u].eb(v); }
int n , a[N] , dis[N] , ans;
void dfs ( int u , int f )
{
int flag = 0;
dis[u] = a[u] ^ dis[f] , s[u].insert(dis[u]);
for ( auto v : e[u] )
{
if ( v == f ) continue;
dfs ( v , u );
if ( s[u].size() < s[v].size() ) swap ( s[u] , s[v] );
for ( auto x : s[v] ) if ( s[u].find(x^a[u]) != s[u].end() ) flag = 1;
for ( auto x : s[v] ) s[u].insert(x);
}
if ( flag ) ans ++ , s[u].clear();
}
signed main ()
{
ios::sync_with_stdio(false);
cin.tie(0) , cout.tie(0);
n = read();
for ( int i = 1 ; i <= n ; i ++ ) a[i] = read();
for ( int i = 1 , u , v ; i < n ; i ++ ) u = read() , v = read() , add ( u , v ) , add ( v , u );
dfs ( 1 , 0 );
cout << ans << endl;
return 0;
}
Blood Cousins
\(Dsu\ on\ tree\) 还有线段树合并写法(在另一篇\(blog\)中)
还是要注意\(jump\)函数的写法罢(\(dep\) \(rev\) \(top\)这些数组要分清)
#include <bits/stdc++.h>
using namespace std;
#define mid (l+r>>1)
#define endl '\n'
#define inl inline
#define ls(p) t[p].son[0]
#define rs(p) t[p].son[1]
#define lson ls(p),l,mid
#define rson rs(p),mid+1,r
#define eb emplace_back
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get();
int read()
{
int x = 0 , f = 1;
char ch = getchar();
while ( !isdigit(ch) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
while ( isdigit(ch) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
return x * f ;
}
int n , m , ans[N] , cnt[N] , l[N] , r[N];
int fa[N] , dep[N] , sz[N] , son[N];
int pos[N] , rev[N] , top[N] , timer;
struct que { int val , id; };
vector<que> q[N];
vector<int> e[N];
inl void add ( int u , int v ) { e[u].eb(v); }
struct LCA
{
void dfs1 ( int u , int f )
{
dep[u] = dep[f] + 1 , fa[u] = f , sz[u] = 1;
for ( auto v : e[u] )
if ( v ^ f )
{
dfs1 ( v , u );
sz[u] += sz[v];
if ( sz[son[u]] < sz[v] ) son[u] = v;//存疑
}
}
void dfs2 ( int u , int tp )
{
top[u] = tp , pos[u] = ++timer , rev[timer] = u;
if ( son[u] ) dfs2 ( son[u] , tp );
for ( auto v : e[u] ) if ( v != fa[u] && v != son[u] ) dfs2 ( v , v );
}
int jump ( int u , int k )
{
while ( k > 0 && u != 0 )
{
if ( pos[u] - pos[top[u]] + 1 > k ) return rev[pos[u]-k];
k -= pos[u] - pos[top[u]] + 1;
u = fa[top[u]];
}
return u;
}
}L;
inl void add ( int x ) { cnt[dep[x]] ++; }
inl void del ( int x ) { cnt[dep[x]] --; }
void dfs ( int u , int f , int keep )
{
for ( auto v : e[u] ) if ( v != son[u] && v != f ) dfs ( v , u , 0 );
if ( son[u] ) dfs ( son[u] , u , 1 );
for ( auto v : e[u] ) if ( v != son[u] && v != f ) for ( int i = pos[v] ; i <= pos[v] + sz[v] - 1 ; i ++ ) add(rev[i]);
add(u);
for ( auto [val,id] : q[u] ) ans[id] = cnt[val];
if ( !keep ) for ( int i = pos[u] ; i <= pos[u] + sz[u] - 1 ; i ++ ) del(rev[i]);
}
signed main ()
{
ios::sync_with_stdio(false);
cin.tie(nullptr) , cout.tie(nullptr);
n = read();
for ( int i = 1 , fa ; i <= n ; i ++ ) fa = read() , add ( fa , i ) , add ( i , fa );
for ( auto v : e[0] ) L.dfs1 ( v , 0 );
for ( auto v : e[0] ) L.dfs2 ( v , v );
m = read();
for ( int i = 1 , x , y ; i <= m ; i ++ )
{
x = read() , y = read();
int lca = L.jump ( x , y );
if ( lca ) q[lca].eb((que){dep[x],i});
else ans[i] = 1;
}
for ( auto v : e[0] ) dfs ( v , 0 , 0 );
for ( int i = 1 ; i <= m ; i ++ ) cout << ans[i] - 1 << ' ';
return 0;
}
P5384 [Cnoi2019] 雪松果树
同上一道题 但是线段树合并会被卡 \(dsu\)可以过
#include <bits/stdc++.h>
using namespace std;
#define mid (l+r>>1)
#define endl '\n'
#define inl inline
#define ls(p) t[p].son[0]
#define rs(p) t[p].son[1]
#define lson ls(p),l,mid
#define rson rs(p),mid+1,r
#define eb emplace_back
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int N = 1e6 + 5;
char buf[1<<24] , *p1 , *p2;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
//#define getchar() cin.get();
int read()
{
int x = 0 , f = 1;
char ch = getchar();
while ( !isdigit(ch) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
while ( isdigit(ch) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
return x * f ;
}
int n , m , ans[N] , cnt[N] , l[N] , r[N];
int fa[N] , dep[N] , sz[N] , son[N];
int pos[N] , rev[N] , top[N] , timer;
struct que { int val , id; };
vector<que> q[N];
vector<int> e[N];
inl void add ( int u , int v ) { e[u].eb(v); }
struct LCA
{
void dfs1 ( int u , int f )
{
dep[u] = dep[f] + 1 , fa[u] = f , sz[u] = 1;
for ( auto v : e[u] )
if ( v ^ f )
{
dfs1 ( v , u );
sz[u] += sz[v];
if ( sz[son[u]] < sz[v] ) son[u] = v;//存疑
}
}
void dfs2 ( int u , int tp )
{
top[u] = tp , pos[u] = ++timer , rev[timer] = u;
if ( son[u] ) dfs2 ( son[u] , tp );
for ( auto v : e[u] ) if ( v != fa[u] && v != son[u] ) dfs2 ( v , v );
}
int jump ( int u , int k )
{
while ( k > 0 && u != 0 )
{
if ( pos[u] - pos[top[u]] + 1 > k ) return rev[pos[u]-k];
k -= pos[u] - pos[top[u]] + 1;
u = fa[top[u]];
}
return u;
}
}L;
inl void add ( int x ) { cnt[dep[x]] ++; }
inl void del ( int x ) { cnt[dep[x]] --; }
void dfs ( int u , int f , int keep )
{
for ( auto v : e[u] ) if ( v != son[u] && v != f ) dfs ( v , u , 0 );
if ( son[u] ) dfs ( son[u] , u , 1 );
for ( auto v : e[u] ) if ( v != son[u] && v != f ) for ( int i = pos[v] ; i <= pos[v] + sz[v] - 1 ; i ++ ) add(rev[i]);
add(u);
for ( auto [val,id] : q[u] ) ans[id] = cnt[val];
if ( !keep ) for ( int i = pos[u] ; i <= pos[u] + sz[u] - 1 ; i ++ ) del(rev[i]);
}
signed main ()
{
ios::sync_with_stdio(false);
cin.tie(nullptr) , cout.tie(nullptr);
n = read();m = read();
for ( int i = 2 , fa ; i <= n ; i ++ ) fa = read() , add ( fa , i ) , add ( i , fa );
L.dfs1 ( 1 , 0 ) , L.dfs2 ( 1 , 1 );
for ( int i = 1 , x , y ; i <= m ; i ++ )
{
x = read() , y = read();
int lca = L.jump ( x , y );
if ( lca ) q[lca].eb((que){dep[x],i});
else ans[i] = 1;
}
dfs ( 1 , 0 , 0 );
for ( int i = 1 ; i <= m ; i ++ ) cout << ans[i] - 1 << ' ';
return 0;
}
Tree Requests
突然感觉\(dsu\)代码更短常数空间更小细节更少一点... 那以后就用\(dsu\)了
线段树合并垃圾回收的时候不能按照正常思路写真的烦人()
对于每一个节点状压维护一个二进制数 表示这个点的字符个数是奇数/偶数个 统计答案的时候如果子树中的字符有两个或以上为奇数的 那么不可以 否则可以
#include <bits/stdc++.h>
using namespace std;
#define mid (l+r>>1)
#define endl '\n'
#define inl inline
#define ls(p) t[p].son[0]
#define rs(p) t[p].son[1]
#define lson ls(p),l,mid
#define rson rs(p),mid+1,r
#define eb emplace_back
const int inf = 0x3f3f3f3f;
const int N = 5e5 + 5;
//char buf[1<<24] , *p1 , *p2;
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<24,stdin),p1==p2)?EOF:*p1++)
#define getchar() cin.get();
int read()
{
int x = 0 , f = 1;
char ch = getchar();
while ( !isdigit(ch) ) { if ( ch == '-' ) f = -1; ch = getchar(); }
while ( isdigit(ch) ) { x = ( x << 1 ) + ( x << 3 ) + ( ch ^ 48 ); ch = getchar(); }
return x * f ;
}
int n , m , dep[N] , timer , rev[N] , son[N] , sz[N] , l[N] , r[N] , a[N];
int cnt[N] , ans[N];
struct que { int val , id; };
vector<que> q[N];
vector<int> e[N];
inl void add ( int u , int v ) { e[u].eb(v); }
void dfs ( int u , int f )
{
dep[u] = dep[f] + 1 , sz[u] = 1 , l[u] = ++timer , rev[timer] = u;
for ( int v : e[u] )
if ( v != f )
{
dfs ( v , u );
sz[u] += sz[v];
if ( sz[son[u]] < sz[v] ) son[u] = v;
}
r[u] = timer;
}
void upd ( int u ) { cnt[dep[u]] ^= ( 1 << a[u] ); }
void update ( int u ) { for ( int i = l[u] ; i <= r[u] ; i ++ ) upd ( rev[i] ); }
void solve ( int u , int f , int keep )
{
for ( int v : e[u] ) if ( v != f && v != son[u] ) solve ( v , u , 0 );
if ( son[u] ) solve ( son[u] , u , 1 );
for ( int v : e[u] ) if ( v != f && v != son[u] ) update(v);
upd(u);
for ( auto [val,id] : q[u] )
{
int res = 0;
for ( int i = 0 ; i < 26 ; i ++ ) res += ( ( cnt[val] & ( 1 << i ) ) == ( 1 << i ) );
ans[id] = res;
}
if ( !keep ) update(u);
}
string s;
signed main ()
{
ios::sync_with_stdio(false);
cin.tie(nullptr) , cout.tie(nullptr);
n = read() , m = read();
for ( int i = 2 , fa ; i <= n ; i ++ ) fa = read() , add ( i , fa ) , add ( fa , i );
cin >> s; s = " " + s;
for ( int i = 1 ; i <= n ; i ++ ) a[i] = s[i] - 'a';
dfs ( 1 , 0 );
for ( int i = 1 , x , y ; i <= m ; i ++ )
{
x = read() , y = read();
q[x].eb((que){y,i});
}
solve ( 1 , 0 , 0 );
for ( int i = 1 ; i <= m ; i ++ ) cout << ( ans[i] > 1 ? "No" : "Yes" ) << endl;
return 0;
}