树链剖分讲解
题目:Aragorn's Story
链接:http://acm.hdu.edu.cn/showproblem.php?pid=3966
题意:给一棵树,每个结点都有初始的权值,有m个操作,分两种:一是从x 结点到y 结点路上所有的结点权值+z或-z,二是问x结点的权值。
思路:
树链剖分。
这是我学树剖的第一题,建议还没接触过的伙伴,第一次学习的时候不要一直纠结理论,直接找一道模板题,然后找一篇AC代码,直接理解,做完一题后,你就会发现理论其实也挺好理解的,树剖也挺好学的。
树剖中的新概念:重孩子、轻孩子、重链、轻链,后面解释
fa[x]:x 结点的父结点
dep[x]:x 结点的深度
siz[x]:以x 结点为根的子树的结点个数
son[x]:x 结点的重孩子,即x 所有孩子中siz 最大的那个(相对的,其他的为轻孩子)
结点x 和他重孩子的连边叫作重边,重边组成的一条链叫作重链。
top[x]:x 所属的重链的头结点
树链剖分就是为了合理地安排每个结点在线段树中的位置,这里,属于同一条重链的结点将分配在一起。
pos[x]:x 结点在线段树中的位置
xd[x]:和pos相反,表示线段树中x 位置的结点是哪个
上面的数组定义给出后,有过基础的完全可以自己求出来了(下面AC代码的dfs1、dfs2),至于求出来什么用,再看下面。
比如现在要求:x 结点到y 结点路上所有结点的权值+1。
分两种:1. 如果x、y 两结点在同一条重链上,那么他们路上的结点其实就是线段树上连续的一个区间[x, y],那么像普通的线段树区间更新似的便可以解决。2. 如果x、y 不在同一个重链上,找到他们的链头,也就是top[x]、top[y],判断哪个深度大,选择深度大的那个,假设为x,现在我们可以更新区间[top[x],x],然后x指向top[x]的父结点,再次判断x、y是否同一重链。
单点查询就不说了,和线段树单点查询一样。
实现弄完了,我们来研究一下为什么他可以快速解决该类问题,从区间更新那里,我们可以看到,如果属于同一条重链,那么接下来的更新操作是线段树操作,这个时间复杂度是logn大家都学过了。也就是说如果有可能慢,那就慢在属于不同的重链,而且慢在必须一直跳(也就是说始终跳不到同一条重链),慢在重链的长度很短(一次只能跳一点)。如果始终没跳到一条重链上,那么跳的次数最多就是树的高度,那么会不会树的高度很大而一次跳很短呢,答案是否定的,因为结点x 的重孩子是其所有孩子结点中siz 最大的那个,如果要跳的y 在重孩子那棵子树,那么边x-son[x]是重边,是可以跳过的,如果要跳的y 在轻孩子z那棵子树,虽然x-z不是重边,是不能跳过去的,但重孩子至少分去了一半的结点,这样层层计算下来,最终时间复杂度也是logn,并不会出现跳的次数过多的情况。
AC代码:
1 #include<stdio.h> 2 #include<vector> 3 #include<algorithm> 4 using namespace std; 5 #define N 100010 6 #define lson rt<<1 7 #define rson rt<<1|1 8 int fa[N],dep[N],siz[N]; 9 int son[N]; 10 vector<int> e[N]; 11 void dfs1(int rt,int f,int h) 12 { 13 dep[rt]=h; fa[rt]=f; siz[rt]=1; 14 for(int i=0;i<e[rt].size();i++) 15 { 16 int ad=e[rt][i]; 17 if(ad!=f) 18 { 19 dfs1(ad,rt,h+1); 20 siz[rt]+=siz[ad]; 21 if(son[rt]==-1 || siz[ad]>siz[son[rt]]) 22 son[rt]=ad; 23 } 24 } 25 } 26 int top[N],pos[N],xd[N],po; 27 void dfs2(int rt,int org) 28 { 29 top[rt]=org; pos[rt]=po++; 30 xd[pos[rt]]=rt; 31 if(son[rt]==-1) return; 32 dfs2(son[rt],org); 33 for(int i=0;i<e[rt].size();i++) 34 { 35 int ad=e[rt][i]; 36 if(ad!=son[rt] && ad!=fa[rt]) 37 dfs2(ad,ad); 38 } 39 } 40 struct Node 41 { 42 int w,c; 43 int l,r; 44 int mid() 45 { 46 return (l+r)/2; 47 } 48 }; 49 Node v[N<<2]; 50 int num[N]; 51 52 void build(int l,int r,int rt) 53 { 54 v[rt].c=0; 55 v[rt].l=l; 56 v[rt].r=r; 57 if(l==r) 58 { 59 v[rt].w=num[xd[l]]; 60 return; 61 } 62 build(l,v[rt].mid(),lson); 63 build(v[rt].mid()+1,r,rson); 64 v[rt].w = v[lson].w+v[rson].w; 65 } 66 void update(int val,int l,int r,int rt) 67 { 68 if(l<=v[rt].l && v[rt].r<=r) 69 { 70 v[rt].c+=val; 71 v[rt].w+=val*(v[rt].r-v[rt].l+1); 72 return; 73 } 74 if(v[rt].c) 75 { 76 v[lson].c += v[rt].c; 77 v[rson].c += v[rt].c; 78 v[lson].w += (v[lson].r-v[lson].l+1)*v[rt].c; 79 v[rson].w += (v[rson].r-v[rson].l+1)*v[rt].c; 80 v[rt].c=0; 81 } 82 int mid=v[rt].mid(); 83 if(l<=mid) update(val,l,r,lson); 84 if(r>mid) update(val,l,r,rson); 85 v[rt].w = v[lson].w+v[rson].w; 86 } 87 void change(int x,int y,int val) 88 { 89 while(top[x]!=top[y]) 90 { 91 if(dep[top[x]]<dep[top[y]]) swap(x,y); 92 update(val,pos[top[x]],pos[x],1); 93 x=fa[top[x]]; 94 } 95 if(dep[x]>dep[y]) swap(x,y); 96 update(val,pos[x],pos[y],1); 97 } 98 int query(int rt,int val) 99 { 100 if(v[rt].l==v[rt].r) return v[rt].w; 101 102 if(v[rt].c) 103 { 104 v[lson].c += v[rt].c; 105 v[rson].c += v[rt].c; 106 v[lson].w += (v[lson].r-v[lson].l+1)*v[rt].c; 107 v[rson].w += (v[rson].r-v[rson].l+1)*v[rt].c; 108 v[rt].c=0; 109 } 110 int mid=v[rt].mid(); 111 int ret=0; 112 if(val<=mid) ret=query(lson,val); 113 else ret=query(rson,val); 114 v[rt].w = v[lson].w+v[rson].w; 115 return ret; 116 } 117 int main() 118 { 119 int n,m,x,y,z; 120 while(~scanf("%d%*d%d",&n,&m)) 121 { 122 po=1; 123 for(int i=1;i<=n;i++) 124 { 125 e[i].clear(); 126 son[i]=-1; 127 scanf("%d",&num[i]); 128 } 129 130 for(int i=1;i<n;i++) 131 { 132 scanf("%d%d",&x,&y); 133 e[x].push_back(y); 134 e[y].push_back(x); 135 } 136 dfs1(1,0,0); 137 dfs2(1,1); 138 build(1,n,1); 139 while(m--) 140 { 141 char s[10]; 142 scanf("%s",s); 143 if(s[0]=='Q') 144 { 145 scanf("%d",&x); 146 printf("%d\n",query(1,pos[x])); 147 } 148 else 149 { 150 scanf("%d%d%d",&x,&y,&z); 151 if(s[0]=='D') z=-z; 152 change(x,y,z); 153 } 154 } 155 } 156 return 0; 157 }