树链剖分(附带LCA和换根)——基于dfs序的树上优化
。。。。
有点懒;
需要先理解几个概念:
1. LCA
2. 线段树(熟练,要不代码能调一天)
3. 图论的基本知识(dfs序的性质)
这大概就好了;
定义:
1.重儿子:一个点所连点树size最大的,这个son被称为这个点的重儿子;
2.轻儿子:一个点所连点除重儿子以外的都是轻儿子;
3.重链:从一个轻儿子或根节点开始沿重儿子走所成的链;
步骤:
在代码里,结合代码更清晰。。。(其实是太懒了)
有重点需要注意的东西在code中有提到,仔细看。。。。
#include<bits/stdc++.h> #define maxn 100007 #define le(x) x<<1 #define re(x) x<<1|1 using namespace std; int n,m,root,mod,a[maxn],head[maxn],fa[maxn],son[maxn],cnt,tag[maxn<<2]; //a:原始点值,fa:父亲节点,son:重儿子,tag:懒标记 int top[maxn],sz[maxn],id[maxn],dep[maxn],w[maxn],cent,tr[maxn<<2]; //top:所在重链的头结点,sz:子树大小,id:dfs序,dep:深度 //w:dfs序所对应的值(建线段树),tr:线段树 struct node{ int next,to; }edge[maxn<<2]; template<typename type_of_scan> inline void scan(type_of_scan &x){ type_of_scan f=1;x=0;char s=getchar(); while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar(); while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar(); x*=f; } inline void add(int u,int v){ edge[++cent]=(node){head[u],v};head[u]=cent; } //-----------------------------------------------------线段树红色预警 void push_up(int p){ tr[p]=tr[le(p)]+tr[re(p)]; tr[p]%=mod; } void build(int l,int r,int p){ if(l==r){ tr[p]=w[l]; return ; } int mid=(l+r)>>1; build(l,mid,le(p)); build(mid+1,r,re(p)); push_up(p); } void push_down(int l,int r,int p,int k){ int mid=l+r>>1; tr[le(p)]+=k*(mid-l+1),tr[re(p)]+=k*(r-mid); tr[le(p)]%=mod,tr[re(p)]%=mod; tag[le(p)]+=k,tag[re(p)]+=k; tag[le(p)]%=mod,tag[re(p)]%=mod; } void r_add(int nl,int nr,int l,int r,int p,int k){ if(nl<=l&&nr>=r){ tr[p]+=k*(r-l+1);tag[p]+=k; tr[p]%=mod,tag[p]%=mod; return ; } push_down(l,r,p,tag[p]),tag[p]=0; int mid=(l+r)>>1; if(nl<=mid) r_add(nl,nr,l,mid,le(p),k); if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k); push_up(p); } int r_query(int nl,int nr,int l,int r,int p){ int ans=0; if(nl<=l&&nr>=r) return tr[p]; push_down(l,r,p,tag[p]),tag[p]=0; int mid=l+r>>1; if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p)),ans%=mod; if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p)),ans%=mod; push_up(p); return ans; } //-----------------------------------------------------线段树结束 //-----------------------------------------------------开始预处理 void dfs1(int x){ sz[x]=1;//sz初始化 int max_part=-1;//max_part更新寻找重儿子 for(int i=head[x];i;i=edge[i].next){ int y=edge[i].to; if(y==fa[x]) continue; fa[y]=x,dep[y]+=dep[x]+1;//更新子节点,准备开始继续dfs1 dfs1(y);sz[x]+=sz[y];//更新自身的sz数组 if(max_part<sz[y]) son[x]=y,max_part=sz[y];//更新重儿子 } } /*dfs1功能介绍 1.更新fa数组; 2.更新dep数组; 3.更新sz数组; 4.更新son数组; */ void dfs2(int x,int t){ id[x]=++cnt,w[cnt]=a[x],top[x]=t;//更新dfs序,dfs序所对的值,重链头节点 if(!son[x]) return ; dfs2(son[x],t); for(int i=head[x];i;i=edge[i].next){ int y=edge[i].to; if(y==fa[x]||y==son[x]) continue; dfs2(y,y); } } /*dfs2功能介绍 1.更新id数组; 2.更新w数组; 3.更新top数组 */ //------------------------------------------------预处理结束 //------------------------------------------------开始主要操作 //其实没有说的这么简单,这里重点是理解重链之间的跳跃方式,线段树的优化 //一个性质:重链上的dfs序是连续的,dfs1在dfs2前的原因就在此 int road_query(int x,int y){ int ans=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下面往上跳 ans+=r_query(id[top[x]],id[x],1,n,1);//更新重链 ans%=mod; x=fa[top[x]];//跳到重链头的fa } if(dep[x]>dep[y]) swap(x,y); ans+=r_query(id[x],id[y],1,n,1);//已经在同一条重链上,直接加 return ans%mod; } int tree_query(int x){ return r_query(id[x],id[x]+sz[x]-1,1,n,1)%mod; }//一个性质:在同一颗子树上的dfs序是连续的 void road_add(int x,int y,int k){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); r_add(id[top[x]],id[x],1,n,1,k); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); r_add(id[x],id[y],1,n,1,k); return ; }//类比 void tree_add(int x,int k){ r_add(id[x],id[x]+sz[x]-1,1,n,1,k); return ; }//相同的性质 //-----------------------------------------------树链剖分 int main(){ scan(n),scan(m),scan(root),scan(mod); for(int i=1;i<=n;i++) scan(a[i]); for(int i=1,u,v;i<=n-1;i++) scan(u),scan(v),add(u,v),add(v,u); dfs1(root),dfs2(root,root),build(1,n,1); for(int i=1;i<=m;i++){ int type,x,y,z; scan(type); if(type==1) scan(x),scan(y),scan(z), road_add(x,y,z); else if(type==2) scan(x),scan(y), printf("%d\n",road_query(x,y)); else if(type==3) scan(x),scan(z), tree_add(x,z); else if(type==4) scan(x), printf("%d\n",tree_query(x)); } return 0; }
好了,可以开始调代码了
拓展:
树链剖分,作为一个优秀的暴力结构,以O(n logn logn)的时间复杂度完成路径查询,在子树查询做到了nlogn级别,所以不得不说其优秀;
但是,它的作用远不及此:
1.LCA查询:
与倍增相同,树链剖分可以用logn的时间复杂度完成LCA查询(跳跃性好像更优),而他的初始化是两遍dfs O(n),理论上更优。
可以猜测,LCA依旧运用重链跳法,然后比较即可,这里给出示范代码
int Lca(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); x=fa[top[x]]; } return dep[x]>dep[y]?y:x; }//只要看懂树链剖分的基本操作,这个很简单
可以看到,其实代码很短。。。
2.换根操作:
设现在的根是root,我们可以发现,换根对于路径上的操作并没有影响,但是子树操作就会影响了,所以我们分类讨论
设u为我们要查的子树的根节点
(1)如果root=u,那么子树即为整棵树;
(2)设 lca 为root和u的LCA,这里可以用上面所讲的树链剖分做,如果lca!=u,那么root并不是u的子节点,所以对于查询并不影响,常规操作即可
(3)如果lca=u,那么u节点的子树就是整颗树减去u-root这个路径上与u相挨的节点v的子树即可,这里给出logn求点v的方法
//前提条件:要求的节点相挨的节点u,必须是root的LCA int find(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳 if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了 x=fa[top[x]];//跳 } if(dep[x]<dep[y]) swap(x,y);//让y最浅 return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的 }
整个操作的代码层次感我写的还是比较清楚了
void tree_add(int x,int k){ if(root==x) r_add(1,n,1,n,1,k);//CASE 1 else{ int lca=Lca(x,root); if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2 else{ int dson=find(x,root); r_add(1,n,1,n,1,k); r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k); }//CASE 3 } return ; } ll tree_query(int x){ if(root==x) return r_query(1,n,1,n,1);//CASE 1 else{ int lca=Lca(x,root); if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2 else{ int dson=find(x,root); return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1); }//CASE 3 } }
推荐评测网站LOJ 。。。(因为洛谷没有换根操作)
AC代码附上
#include<bits/stdc++.h> #define maxn 100007 #define ol putchar('\n') #define le(x) x<<1 #define re(x) x<<1|1 #define ll long long using namespace std; int n,m,head[maxn],cent,dep[maxn],son[maxn],fa[maxn],vis[maxn]; int top[maxn],a[maxn],id[maxn],w[maxn],sz[maxn],cnt,ij,root; ll tr[maxn<<3],tag[maxn<<3]; struct node{ int next,to; }edge[maxn<<3]; template<typename type_of_scan> inline void scan(type_of_scan &x){ type_of_scan f=1;x=0;char s=getchar(); while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar(); while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar(); x*=f; } template<typename type_of_print> inline void print(type_of_print x){ if(x<0) putchar('-'),x=-x; if(x>9) print(x/10); putchar(x%10+'0'); } inline void add(int u,int v){ edge[++cent]=(node){head[u],v};head[u]=cent; } void push_up(int p){ tr[p]=tr[le(p)]+tr[re(p)]; } void push_down(int l,int r,int p,ll k){ int mid=l+r>>1; tr[le(p)]+=1ll*(mid-l+1)*k, tr[re(p)]+=1ll*(r-mid)*k, tag[le(p)]+=k,tag[re(p)]+=k; } void build(int l,int r,int p){ if(l==r){ tr[p]=w[l]; return ; } int mid=l+r>>1; build(l,mid,le(p)); build(mid+1,r,re(p)); push_up(p); } void r_add(int nl,int nr,int l,int r,int p,int k){ if(nl<=l&&nr>=r){ tr[p]+=1ll*(r-l+1)*k; tag[p]+=1ll*k; return ; } push_down(l,r,p,tag[p]),tag[p]=0; int mid=l+r>>1; if(nl<=mid) r_add(nl,nr,l,mid,le(p),k); if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k); push_up(p); } ll r_query(int nl,int nr,int l,int r,int p){ ll ans=0; if(nl<=l&&nr>=r) return tr[p]; push_down(l,r,p,tag[p]),tag[p]=0; int mid=l+r>>1; if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p)); if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p)); return ans; } void dfs1(int x){ sz[x]=1;int max_part=-1;vis[x]++; for(int i=head[x];i;i=edge[i].next){ int y=edge[i].to; if(y==fa[x]) continue; fa[y]=x;dep[y]=dep[x]+1; dfs1(y);sz[x]+=sz[y]; if(max_part<sz[y]) son[x]=y,max_part=sz[y]; } } void dfs2(int x,int t){ id[x]=++cnt;w[cnt]=a[x];top[x]=t; if(!son[x]) return ; dfs2(son[x],t); for(int i=head[x];i;i=edge[i].next){ int y=edge[i].to; if(y==son[x]||fa[x]==y) continue; dfs2(y,y); } } int Lca(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); x=fa[top[x]]; } return dep[x]>dep[y]?y:x; }//只要看懂树链剖分的基本操作,这个很简单 //前提条件:要求的节点相挨的节点u,必须是root的LCA int find(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳 if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了 x=fa[top[x]];//跳 } if(dep[x]<dep[y]) swap(x,y);//让y最浅 return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的 } void tree_add(int x,int k){ if(root==x) r_add(1,n,1,n,1,k);//CASE 1 else{ int lca=Lca(x,root); if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2 else{ int dson=find(x,root); r_add(1,n,1,n,1,k); r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k); }//CASE 3 } return ; } ll tree_query(int x){ if(root==x) return r_query(1,n,1,n,1);//CASE 1 else{ int lca=Lca(x,root); if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2 else{ int dson=find(x,root); return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1); }//CASE 3 } } void road_add(int x,int y,ll k){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); r_add(id[top[x]],id[x],1,n,1,k); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); r_add(id[x],id[y],1,n,1,k); return ; } ll road_query(int x,int y){ ll ans=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); ans+=r_query(id[top[x]],id[x],1,n,1); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); ans+=r_query(id[x],id[y],1,n,1); return ans; } int main(){ // freopen("cin.in","r",stdin); // freopen("co.out","w",stdout); scan(n); for(int i=1;i<=n;i++) scan(a[i]); for(int i=2,v;i<=n;i++) scan(v),add(i,v),add(v,i); dfs1(1),dfs2(1,1),build(1,n,1);root=1; scan(m); for(int i=1;i<=m;i++){ int type,x,y,z; scan(type),scan(x); if(type==1) root=x; else if(type==2) scan(y),scan(z),road_add(x,y,z); else if(type==3) scan(z),tree_add(x,z); else if(type==4) scan(y),printf("%lld\n",road_query(x,y)); else if(type==5) printf("%lld\n",tree_query(x)); } return 0; }