[OI] 树链剖分
学的时候比较朦胧,现在不朦胧了,所以写一下
讲解
重儿子:一个节点的子树大小最大的儿子
轻儿子:非重儿子
重链:节点 -> 重儿子 -> 重儿子 .. 这样的链
A beautiful Tree
蓝线为重链
可以发现,树上的所有节点一定属于且仅属于一个重链
首先要知道如何找重链
这很简单,可以通过一遍 DFS 得到:
void dfs(int now){
size[now]=1;
int maxsonsize=0;
for(i:遍历所有子节点){
dfs(i)
if(size[i]>maxsonsize){
maxson[now]=i;
maxsonsize=size[i]
}
size[now]+=size[i]
}
}
其中 size
是节点的子树大小
为什么一定要剖出重链来?因为我们要进行的是在链上跳跃的操作,而我们可以跳跃的范围是一整条链,因此链越长,对复杂度就越有利,而且我们将不同的重链剖出来,还能保证每一个节点都在一条重链上,不重不漏
找出重儿子以后怎么找重链呢
这个就更简单了,我们再做一遍 DFS,记录每个链顶端的节点,然后将其赋给链中的每一个节点(或者你在这里开个 cnt 也是可以的,只要能起到区分的作用就行),这样,值相同的节点就一定在同一条链里了
void dfs2(int now,int topnode){
top[now]=topnode;
if(maxson[now]==0) return; //没有儿子就返回
dfs2(maxson[now],top_node) //搜索重儿子,此时不改变链
for(i:遍历子节点){
if(i!=maxson[now]){
dfs2(i,i); //轻儿子的重链顶端就是这个轻儿子,可以看上面的图
} //如果你在这里写 cnt 的话就是 ++cnt
}
}
实际上我们还需要在这两遍 DFS 中维护一些信息,具体的信息列在下面:
DFS1
- 节点父亲
- 节点深度
- 节点子树大小
- 节点的重儿子编号
DFS2
- 构建链
- 按遍历顺序为每个节点分配新编号
- 将原节点权值迁移到新编号
可以写出下面两个代码:
void dfs1(int now,int last){
fa[now]=last;
deep[now]=deep[last]+1;
size[now]=1;
int maxsonsize=0;
for(int i:e[now]){
if(i!=last){
dfs1(i,now);
if(size[i]>maxsonsize){
maxson[now]=i;
maxsonsize=size[i];
}
size[now]+=size[i];
}
}
}
void dfs2(int now,int nowtop,int last){
id[now]=++cnt;
wnew[id[now]]=w[now];
top[now]=nowtop;
if(!maxson[now]) return;
dfs2(maxson[now],nowtop,now);
for(int i:e[now]){
if(i!=last and i!=maxson[now]){
dfs2(i,i,now);
}
}
}
这里我们给每个节点都分配了新的编号,那么有什么用吗
因为我们这么分配编号有两个非常好的性质
-
同一个重链上的点,编号总是连续的,并且上面的节点编号总是比下面的节点编号要小
-
同一个子树中的点,编号是一个连续区间,并且这个区间总是 \([id_r,id_r+size-1]\)(\(r\) 是子树根节点)
但是需要注意的是,为了实现这两个非常好的性质,我们需要在 DFS2 中优先遍历重儿子,因为重儿子和当前节点在一条链中,只有优先遍历了重儿子,才能保证按遍历顺序分配的编号是连续的
那么有了这两个非常好的性质,我们可以干什么呢
- 查询路径信息
假如有一道题让我们查询树上 \((x,y)\) 的简单路径权值和(点权)
那么我们可以考虑这样降低复杂度:
- 如果 \(x,y\) 不在一条链上,将其中链顶深度较小的那个节点跳到它所在的链顶,同时统计该节点到其顶端的答案
- 重复如上操作,直到 \(x,y\) 在一条链上
- 此时直接统计即可
以上操作中,由于我们只在同一条链上跳,因此编号总是连续的,所以可以用数据结构来维护
下面是一份线段树维护的查询
int ask_path_sum(int x,int y){
int res=0;
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]]) swap(x,y);
res+=ask_sum(1,id[top[x]],id[x]);
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
res+=ask_sum(1,id[y],id[x]);
return res;
}
路径修改同理
然后考虑怎么用第二个性质
第二个性质也非常好,可以用来作子树整体修改/查询
int ask_subtree(int x){
return stree::ask_sum(1,id[x],id[x]+size[x]-1);
}
例题
树的统计
- 单点修改
- 路径和查询
- 路径最值查询
这两个信息都能用线段树来维护
单点修改总是简单的,直接在线段树上定位即可
#include<bits/stdc++.h>
using namespace std;
#define int long long
int deep[200001],fa[200001],size[200001],maxson[200001];
vector<int>e[200001];
void dfs1(int now,int last){
fa[now]=last;
deep[now]=deep[last]+1;
size[now]=1;
int maxsonsize=0;
for(int i:e[now]){
if(i!=last){
dfs1(i,now);
if(size[i]>maxsonsize){
maxson[now]=i;
maxsonsize=size[i];
}
size[now]+=size[i];
}
}
}
int w[200001];
int id[200001],top[200001],wnew[200001];
int cnt=0;
void dfs2(int now,int nowtop,int last){
id[now]=++cnt;
wnew[id[now]]=w[now];
top[now]=nowtop;
if(!maxson[now]) return;
dfs2(maxson[now],nowtop,now);
for(int i:e[now]){
if(i!=last and i!=maxson[now]){
dfs2(i,i,now);
}
}
}
namespace stree{
struct tree{
int l,r;
int sum,max;
}t[800001];
#define tol (id*2)
#define tor (id*2+1)
#define mid(l,r) mid=((l)+(r))/2
void build(int id,int l,int r){
t[id].l=l;t[id].r=r;
if(l==r){
t[id].sum=wnew[l];
t[id].max=wnew[l];
return;
}
int mid(l,r);
build(tol,l,mid);
build(tor,mid+1,r);
t[id].sum=(t[tol].sum+t[tor].sum);
t[id].max=max(t[tol].max,t[tor].max);
}
int ask_sum(int id,int l,int r){
if(l>r) swap(l,r);
if(l<=t[id].l and t[id].r<=r){
return t[id].sum;
}
pushdown(id);
if(r<=t[tol].r) return ask_sum(tol,l,r);
else if(l>=t[tor].l) return ask_sum(tor,l,r);
else{
return (ask_sum(tol,l,t[tol].r)+ask_sum(tor,t[tor].l,r));
}
}
int ask_max(int id,int l,int r){
if(l>r) swap(l,r);
if(l<=t[id].l and t[id].r<=r){
return t[id].max;
}
pushdown(id);
if(r<=t[tol].r) return ask_max(tol,l,r);
else if(l>=t[tor].l) return ask_max(tor,l,r);
else{
return max(ask_max(tol,l,t[tol].r),ask_max(tor,t[tor].l,r));
}
}
}
int ask_path_max(int x,int y){
int res=-1;
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]]) swap(x,y);
res=max(res,stree::ask_max(1,id[top[x]],id[x]));
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
res=max(res,stree::ask_max(1,id[y],id[x]));
return res;
}
int ask_path_sum(int x,int y){
int res=0;
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]]) swap(x,y);
res+=stree::ask_sum(1,id[top[x]],id[x]);
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
res+=stree::ask_sum(1,id[y],id[x]);
return res;
}
int n,m;
signed main(){
scanf("%lld",&n);
for(int i=1;i<=n-1;++i){
int x,y;
scanf("%lld %lld",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
for(int i=1;i<=n;++i){
scanf("%lld",&w[i]);
}
scanf("%lld",&m);
dfs1(1,0);
dfs2(1,1,0);
stree::build(1,1,n);
while(m--){
string op;int x,y,z;cin>>op;
if(op[0]=='C'){
scanf("%lld %lld",&x,&z);
stree::change(1,id[x],id[x],z-stree::ask_sum(1,id[x],id[x]));
}
if(op[0]=='Q' and op[1]=='M'){
scanf("%lld %lld",&x,&y);
printf("%lld\n",ask_path_max(x,y));
}
if(op[0]=='Q' and op[1]=='S'){
scanf("%lld %lld",&x,&y);
printf("%lld\n",ask_path_sum(x,y));
}
}
}