dsu on tree 学习笔记
适用范围
- 支持离线处理
- 每个询问都是针对某棵子树
- 没有修改操作
原理
- 在递归遍历一整棵树并更新每棵子树产生的影响时,正常情况下我们每次递归完一棵子树都要将其撤销,否则会对其兄弟节点造成干扰。
- 而 \(dsu\) \(on\) \(tree\) 仅利用了一个性质,就是在
dfs
的过程中,最后遍历的那个不需要撤销,因为它的兄弟节点已经遍历完了,这时候我们就可以直接让父节点继承这个点的信息。 - 因为最后的子树不需要撤销,所以我们想要让其越大越好,也就是撤销的越少越好,这样会保证我们的时间效率尽可能地高
- 核心也就有了:利用树链剖分的性质,最后递归重儿子且不撤销,优先遍历轻儿子并撤销,父节点继承重儿子的信息
时间复杂度
- 大部分人看到这里都会觉得:这有什么区别吗?不就少搞一个重儿子吗。然而事实上是,它将 \(O(n^2)\) 的效率变成了 \(O(nlogn)\) 的效率
- 依据树链剖分的性质,一个子树内轻边的数量保证不会超过总边数的一半,考虑最坏的情况,每次
dfs
到一个新节点,那么轻边数就会减半,最后遍历完后就只有了 \(logn\) 条轻边
例题
T1 树上数颜色
T2 CF600E Lomsat gelral
T3 射手座之日
这两题都是模板题,这里放上 \(T1\) 的代码作为板子用以参考
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define R register
#define N 100010
using namespace std;
inline int read(){
int x = 0,f = 1;
char ch = getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
int n,m,head[N],siz[N],f[N],son[N],cnt[N],c[N];
long long ans[N];
struct edge{
int to,next;
}e[N<<1];
int len;
void addedge(int u,int v){
e[++len].to = v;
e[len].next = head[u];
head[u] = len;
}
void dfs(int u,int fa){//求重儿子
siz[u] = 1;
f[u] = fa;
for(R int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v==fa)continue;
dfs(v,u);
siz[u] += siz[v];
if(siz[v]>siz[son[u]])son[u] = v;
}
}
void update(int u,int val,int p){//统计答案,p表示需要跳过的点
cnt[c[u]] += val;
if(val==1&&cnt[c[u]]==1)sum++;
if(val==-1&&cnt[c[u]]==0)sum--;
for(R int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v==f[u]||v==p)continue;
update(v,val,p);
}
}
void dsu(int u,int opt){//opt表示撤销or不撤销对答案的影响
for(R int i = head[u];i;i = e[i].next){
int v = e[i].to;
if(v==f[u]||v==son[u])continue;//先递归轻儿子
dsu(v,0);
}
if(son[u])dsu(son[u],1);//最后递归重儿子
update(u,1,son[u]);//统计轻儿子的答案,记得跳过重儿子
ans[u] = sum;
if(opt==0){//撤销影响
update(u,-1,0);
sum = 0;
}
}
int main(){
n = read();
for(R int i = 1;i < n;i++){
int u = read(),v = read();
addedge(u,v),addedge(v,u);
}
for(R int i = 1;i <= n;i++)c[i] = read();
dfs(1,0);
dsu(1,0);
m = read();
for(R int i = 1;i <= m;i++){
int x = read();
printf("%lld\n",ans[x]);
}
return 0;
}