dsu on tree
\(dsu on tree\)
定义:
这里简称 \(dsu\) 好了。
这个主要是用来解决一类树上询问问题,一般有两个特征:
- 只有对子树的询问
- 没有修改
这时候就可以用 \(dsu\) 了。
可能特征 \(1\) 不显然,题目中不明确问你子树 \(i\) 的答案,需要把问题转化后算子树 \(i\) 的答案。
算法流程:
我们以一道例题为例子:
大意:有一棵 \(n\) 个节点的树,一号节点为根节点,每个点都有一种颜色,询问节点 \(i\) 的子树中颜色种类数最多的子树,大小是多少。
考虑暴力:
遍历每一个节点,把子树中所有颜色暴力统计更新答案,消除该节点贡献,继续统计。 时间复杂度为 \(O(n^2)\)
\(dsu\) 利用轻重链剖分,把复杂度降低到 \(O( /log n)\)
流程:
· 遍历每一个节点
递归解决所有的轻儿子,同时消除递归产生的影响
· 递归重儿子,不消除递归影响
· 统计所有轻儿子对答案的影响
· 更新该节点答案
· 删除所有轻儿子对答案影响
主体框架:
void dfs(int x,int fa,int opt){
for(int i=head[x];i;i=nxt[i]){
int y=ver[i]; if(y==fa) continue;
if(y!=son[x]) dfs(y,x,0);//暴力统计轻边贡献
}
if(son[x]) dfs(son[x],x,1);//统计重儿子贡献,不消除影响
add(x);//暴力统计所有轻儿子贡献
ans[x]=nowans;//更新答案
if(!opt) delete(x); //需要删除贡献就删掉。
}
时间复杂度为 \(O(n\log n)\)
例题:
题意:
给出一棵 \(n\) 个结点的树,每个结点都有一种颜色编号,求该树中每棵子树里的出现次数最多的颜色的编号和。
流程:
首先,预处理出每个点的重儿子 \(son[x]\)
对于一个点 \(x\) ,首先计算他所有轻儿子的 \(ans\) ,并且每算完一个儿子就要清楚它所有的贡献。
for(int i=head[x];i;i=nxt[i]){//计算轻儿子的答案并清空贡献
int y=ver[i];
if(y==fa||y==son[x]) continue;
solve(y,x);
cl();
}
接下来,计算重儿子 \(son[x]\) 的答案,并保留 \(subtree(son[x])\) (重儿子子树) 中所有点的贡献。
if(son[x]) solve(son[x],x);//不清除
然后再暴力加入 \(subtree(x)\) 中除了 \(subtree(son[x])\) 以外所有点的贡献,此时就可以得到 \(ans[x]\)
for(int i=head[x];i;i=nxt[i]){//暴力统计轻儿子贡献
int y=ver[i];
if(y==fa||y==son[x]) continue;
addson(y,x);
}
insert(x);//注意还要把x的贡献也加进去
Ans[x]=ans;
之后回溯到 \(fa[x]\) ,如果 \(x\) 是 \(fa[x]\) 轻儿子,那就清除贡献,否则保留。
其中,统计答案用到了反复更新加权的写法:
当答案更加优秀,更新优秀答案; 答案一样优秀,优秀答案的次数增加。
时间复杂度分析:
感觉是不是会爆炸?其实不会。
通过输出转移,我们发现:
每一个点只会在其到根路径中若干重链交界处被统计(也即是从该点到根路径上的轻边数量)
因为重链剖分性质,一个点向上跳时,重链数量不超过 \(\log n\),因此轻边的数量也不会超过 \(\log n\), 因此每个点的统计次数也为 \(\log n\)。
因此,总复杂度为 \(O(n \log n)\) ,并且带上一些常数。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N=4e5+5;
#define ll long long
int n;
int A[N],t,tmp;
int nxt[N],ver[N],tot,head[N];
int sizes[N],son[N],cnt[N],Q[N],edg;
ll ans,Ans[N];
void add(int x,int y){
ver[++tot]=y; nxt[tot]=head[x]; head[x]=tot;
}
void dfs(int x,int fa){
sizes[x]=1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i]; if(y==fa) continue;
dfs(y,x);
sizes[x]+=sizes[y];
if(sizes[y]>sizes[son[x]]||!son[x]) son[x]=y;
}
}
void cl(){
while(t) cnt[Q[t--]]=0; ans=tmp=0;
}
void insert(int x){
Q[++t]=A[x];
cnt[Q[t]]++;
//cnt记录颜色数量,tmp记录子树中最大颜色数,ans记录编号之和
if(cnt[A[x]]>tmp) tmp=cnt[ans=A[x]];
else if(cnt[A[x]]==tmp) ans+=A[x];
}
void addson(int x,int fa){
insert(x);//暴力统计轻儿子答案
for(int i=head[x];i;i=nxt[i]){
int y=ver[i]; if(y==fa) continue;
addson(y,x);
}
}
void solve(int x,int fa){
// cout<<++edg<<" "<<x<<" "<<fa<<" "<<Ans[x]<<endl;
for(int i=head[x];i;i=nxt[i]){//计算轻儿子的答案并清空贡献
int y=ver[i];
if(y==fa||y==son[x]) continue;
solve(y,x);
cl();
}
if(son[x]) solve(son[x],x);
//计算重儿子的答案并保留subtree(son[x])(以son[x]为根的整棵子树贡献)
for(int i=head[x];i;i=nxt[i]){//暴力统计轻儿子贡献
int y=ver[i];
if(y==fa||y==son[x]) continue;
addson(y,x);
//加入subtree(x)-subtree(son[x])-x(以x的所有轻儿子为根的子树贡献)
}
insert(x);//注意还要把x的贡献也加进去
Ans[x]=ans;
}
int main(){
cin>>n;
for(int i=1;i<=n;i++) scanf("%d",&A[i]);
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y); add(x,y); add(y,x);
}
dfs(1,0);
solve(1,0);
for(int i=1;i<=n;i++)
printf("%lld ",Ans[i]); puts("");
system("pause");
return 0;
}
CF741D Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths
题意:
给你一个有向树,每个节点表示 \([a,v]\) 之间的一个字母,让你算出每个节点的子树的字母重新排列后是否能形成一个回文串,如果可以,输出这个回文串的长度。
分析:
考虑什么时候形成回文串:
- 所有字母出现次数被 \(2\) 整除
- 只有一个字母出现次数为奇数,其他都为 \(2\) 的倍数。
考虑 \([a,v]\) 一共 \(22\) 位,考虑用状压存储,\(x\) 子树能形成回文串的条件就是:
\(val[x]^val[y]=0/2^k\)
考虑使用 \(dsu\) 解决问题:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5,M=1<<22,inf=0x3f3f3f3f;
int nxt[N],ver[N],tot,head[N],edge[N];
int sizes[N],dep[N],son[N],cnt;
int n,m;
int D[(M)+5],now,Ans[N],mx,col[M];
void add(int x,int y,int z){
ver[++tot]=y; edge[tot]=z; nxt[tot]=head[x]; head[x]=tot;
}
void dfs(int x){
sizes[x]=1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i],z=edge[i];
D[y]=D[x]^(1<<z);
dep[y]=dep[x]+1;
dfs(y);
sizes[x]+=sizes[y];
if(sizes[y]>sizes[son[x]]||!son[x]) son[x]=y;
}
}
void clear(int x){
col[D[x]]=0;
for(int i=head[x];i;i=nxt[i]) clear(ver[i]);
}
void cal(int x){
if(col[D[x]]) mx=max(mx,dep[x]+col[D[x]]-now);
for(int i=0;i<22;i++)
if(col[(1<<i)^D[x]]) mx=max(mx,dep[x]+col[(1<<i)^D[x]]-now);
}
void upd(int x){
col[D[x]]=max(dep[x],col[D[x]]);//更新
}
void calc(int x){
cal(x);
for(int i=head[x];i;i=nxt[i]) calc(ver[i]);
}
void update(int x){
upd(x);
for(int i=head[x];i;i=nxt[i]) update(ver[i]);
}
void solve(int x,int fa){
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==son[x]) continue;
solve(y,x);
clear(y);
}
if(son[x]) solve(son[x],x);
now=dep[x]<<1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i]; mx=max(Ans[y],mx);
}
for(int i=head[x];i;i=nxt[i]){
int y=ver[i]; if(y==son[x]) continue;
calc(y); update(y);
}
cal(x); upd(x);
Ans[x]=mx;
if(x!=son[fa]) clear(x),mx=0;
}
int main(){
cin>>n;
for(int i=2,y;i<=n;i++){ char ch[3];
scanf("%d%s",&y,ch);
add(y,i,ch[0]-'a');
}
dfs(1);
// for(int i=0;i<=n;i++) cout<<D[i]<<" "; puts("");
solve(1,1);
for(int i=1;i<=n;i++) printf("%d ",Ans[i]); puts("");
system("pause");
return 0;
}