树链剖分 学习整理
“在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。
树链,就是树上的路径。剖分,就是把路径分类为重链和轻链。
记siz[v]表示以v为根的子树的节点数,dep[v]表示v的深度(根深度为1),top[v]表示v所在的链的顶端节点,fa[v]表示v的父亲,son[v]表示与v在同一重链上的v的儿子节点(姑且称为重儿子),w[v]表示v与其父亲节点的连边(姑且称为v的父边)在线段树中的位置。只要把这些东西求出来,就能用logn的时间完成原问题中的操作。
重儿子:siz[u]为v的子节点中siz值最大的,那么u就是v的重儿子。
轻儿子:v的其它子节点。
重边:点v与其重儿子的连边。
轻边:点v与其轻儿子的连边。
重链:由重边连成的路径。
轻链:轻边。
剖分后的树有如下性质:
性质1:如果(v,u)为轻边,则siz[u] * 2 < siz[v];
性质2:从根到某一点的路径上轻链、重链的个数都不大于logn。
算法实现:
我们可以用两个dfs来求出fa、dep、siz、son、top、w。
第一遍dfs:把fa、dep、siz、son求出来,比较简单,略过。
第二遍dfs:⒈对于v,当son[v]存在(即v不是叶子节点)时,显然有top[son[v]] = top[v]。线段树中,v的重边应当在v的父边的后面,记w[son[v]] = totw+1,totw表示最后加入的一条边在线段树中的位置。此时,为了使一条重链各边在线段树中连续分布,应当进行dfs_2(son[v]);
⒉对于v的各个轻儿子u,显然有top[u] = u,并且w[u] = totw+1,进行dfs_2过程。
这就求出了top和w。
将树中各边的权值在线段树中更新,建链和建线段树的过程就完成了。
修改操作:例如将u到v的路径上每条边的权值都加上某值x。
一般人需要先求LCA,然后慢慢修改u、v到公共祖先的边。而高手就不需要了。
记f1 = top[u],f2 = top[v]。
当f1 <> f2时:不妨设dep[f1] >= dep[f2],那么就更新u到f1的父边的权值(logn),并使u = fa[f1]。
当f1 = f2时:u与v在同一条重链上,若u与v不是同一点,就更新u到v路径上的边的权值(logn),否则修改完成;
重复上述过程,直到修改完成。
求和、求极值操作:类似修改操作,但是不更新边权,而是对其求和、求极值。
就这样,原问题就解决了。鉴于鄙人语言表达能力有限,咱画图来看看:
如右图所示,较粗的为重边,较细的为轻边。节点编号旁边有个红色点的表明该节点是其所在链的顶端节点。边旁的蓝色数字表示该边在线段树中的位置。图中1-4-9-13-14为一条重链。
当要修改11到10的路径时。
第一次迭代:u = 11,v = 10,f1 = 2,f2 = 10。此时dep[f1] < dep[f2],因此修改线段树中的5号点,v = 4, f2 = 1;
第二次迭代:dep[f1] > dep[f2],修改线段树中10--11号点。u = 2,f1 = 2;
第三次迭代:dep[f1] > dep[f2],修改线段树中9号点。u = 1,f1 = 1;
第四次迭代:f1 = f2且u = v,修改结束。
例题 : SPOJ 375 的代码
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 using namespace std; 5 int n,m; 6 struct node { 7 int l,r,sum; 8 }tree[100010*4]; 9 void built(int l,int r,int k) 10 { 11 tree[k].l=l;tree[k].r=r; 12 if(l==r){ scanf("%d",&tree[k].sum);return ; } 13 int mid=(l+r)/2; 14 built(l,mid,k*2);built(mid+1,r,k*2+1); 15 tree[k].sum=tree[k*2].sum+tree[k*2+1].sum; 16 } 17 void change(int k,int pos,int x) 18 { 19 int l=tree[k].l,r=tree[k].r; 20 if(l==r){ tree[k].sum+=x;return; } 21 int mid=(l+r)/2; 22 if(pos<=mid) change(k*2,pos,x); 23 if(pos>mid) change(k*2+1,pos,x); 24 tree[k].sum=tree[k*2].sum+tree[k*2+1].sum; 25 } 26 int query(int k,int l,int r)// 区间查询(以求和为例) 27 { 28 int ans=0; 29 if(l==tree[k].l&&r==tree[k].r) { return tree[k].sum; } 30 int mid=(tree[k].l+tree[k].r)/2; 31 if(l<=mid) ans+=query(k*2,l,min(mid,r)); 32 if(r>mid) ans+=query(k*2+1,max(mid+1,l),r); 33 return ans; 34 } 35 int find(int k,int pos) 36 { 37 if(tree[k].l==tree[k].r) { return tree[k].sum; } 38 int mid=(tree[k].l+tree[k].r)/2; 39 if(pos<=mid) find(k*2,pos); 40 if(pos>mid) find(k*2+1,pos); 41 } 42 void allchange(int k,int ls,int rs,int x) 43 { 44 int l=tree[k].l,r=tree[k].r; 45 if(l==r){tree[k].sum+=x;return;} 46 int mid=(l+r)/2; 47 if(ls<=mid) allchange(k*2,ls,min(rs,mid),x); 48 if(rs>mid) allchange(k*2+1,max(ls,mid+1),rs,x); 49 } 50 int main()// 线段树 维护 区间求和 和 单点修改 51 { 52 scanf("%d",&n); 53 built(1,n,1); 54 change(1,x,a);// 在x的为位置上增加a 55 query(1,x,y); 56 find(1,x);// 单点查询x 57 allchange(1,x,y,z);//区间x到y 的值全部增加 z 58 return 0; 59 }
未完成代码存档:
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #define clr(a,b) memset(a,b,sizeof(a)) 5 using namespace std; 6 const int maxn=10010; 7 struct node{ 8 int from,to,value,next; 9 }e[5005*4]; 10 struct tre{ 11 int l,r,value; 12 }tree[maxn*4]; 13 int n,pos[maxn],dep[maxn],head[maxn],m,T,p; 14 int fa[maxn],siz[maxn],son[maxn],top[maxn],nid; 15 void add(int from,int to,int value){ 16 m++; 17 e[m].from=from;e[m].to=to;e[m].value=value;e[m].next=head[from];head[from]=m; 18 } 19 void clear(){ 20 p=0;m=0;clr(head,0);clr(dep,0);clr(siz,0);clr(son,0);clr(e,0); 21 for(int i=1,x,y,z;i<=n-1;i++){ 22 scanf("%d%d%d",&x,&y,&z); 23 add(x,y,z);add(y,x,z); 24 } 25 } 26 void dfs_1(int s,int fu,int deepth){// 父节点 深度 重孩子 27 fa[s]=fu;dep[s]=deepth;son[s]=-1;siz[s]=1; 28 for(int i=head[s];i;i=e[i].next){ 29 int to=e[i].to;if(to==fu) continue; 30 dfs_1(to,s,deepth+1);siz[s]+=siz[to]; 31 if(!son[s]||siz[son[s]]<siz[to]) son[s]=to; 32 } 33 } 34 void gettop(int s,int f){// 链顶节点 线段树中的位置 35 top[s]=f;pos[s]=++p; 36 if(!son[s]) return; 37 gettop(son[s],f); 38 for(int i=head[s];i;i=e[i].next){ 39 int v=e[i].to; 40 if(v!=son[s]&&v!=fa[s]) gettop(v,v); 41 } 42 } 43 void built(int l,int r,int k){// 第一步是建一颗 空树 44 tree[k].l=l;tree[k].r=r; 45 if(l==r) { return; } 46 int mid=(l+r)/2; 47 built(l,mid,k*2);built(mid+1,r,k*2+1); 48 } 49 void update(int k,int ps,int val){ 50 if(tree[k].l==tree[k].r) 51 {tree[k].value=val;return; } 52 int mid=(tree[k].l+tree[k].r)/2; 53 if(ps<=mid) update(k*2,ps,val); 54 else update(k*2+1,ps,val); 55 tree[k].value=max(tree[k*2].value,tree[k*2+1].value); 56 } 57 int query(int k,int l,int r) 58 { 59 if(tree[k].l==tree[k].r)return tree[k].value; 60 int mid=(tree[k].l+tree[k].r)/2; 61 int ans=0; 62 if(l<=mid)ans=max(ans,query(p*2,l,r)); 63 if(r>mid)ans=max(ans,query(p*2+1,l,r)); 64 return ans; 65 } 66 int find(int u,int v){ 67 int t1=top[u],t2=top[v],ans=0; 68 while(t1!=t2){ 69 if(dep[t1]<dep[t2]){ swap(t1,t2);swap(u,v); } 70 ans=max(ans,query(1,pos[t1],pos[t2])); 71 u=fa[t1];t1=top[u]; 72 } 73 if(u==v) return ans; 74 if(dep[u]>dep[v]) swap(u,v); 75 return max(ans,query(1,pos[u]+1,pos[v])); 76 } 77 int main() 78 { 79 scanf("%d",&T); 80 while(T--) 81 { 82 clear(); 83 dfs_1(1,0,1); 84 gettop(1,1); 85 built(1,p,1); 86 for(int i=1;i<=2*n-2;i+=2){ 87 if(dep[e[i].to]<dep[e[i].from]) swap(e[i].to,e[i].from); 88 update(1,pos[e[i].to],e[i].value); 89 } 90 char s[15];int u,v; 91 while(scanf("%s",s)==1) 92 { 93 if(s[0]=='D') break; 94 if(s[0]=='Q') printf("%d\n",find(u,v)); 95 else update(1,pos[e[u*2-1].to],v); 96 } 97 } 98 return 0; 99 }
备注:引用自网络