树链剖分模板

概念

树剖就是将一棵树暴力拆成几条链,然后对于这样一个序列,我们就可以套上资瓷区间处理的一些东西qwq(比如说线段树,树状数组

可以解决的问题:

  • 将树从$x$到$y$结点最短路径上所有节点的值都加上$z$
  • 求树从$x$到$y$结点最短路径上所有节点的值之和/最大值
  • 将以$x$为根节点的子树内所有节点值都加上$z$
  • 求以$x$为根节点的子树内所有节点值之和/最大值

一些概念:

  • 重儿子:父亲节点的所有儿子中子树结点数目最多($size$最大)的结点;
  • 轻儿子:父亲节点中除了重儿子以外的儿子;
  • 重边:父亲结点和重儿子连成的边;
  • 轻边:父亲节点和轻儿子连成的边;
  • 重链:由多条重边连接而成的路径;
  • 轻链:由多条轻边连接而成的路径;

实现

一些定义:

  • $f(x)$表示节点$x$在树上的父亲
  • $dep(x)$表示节点$x$在树上的深度
  • $siz(x)$表示节点$x$的子树的节点的个数
  • $son(x)$表示节点$x$的重儿子
  • $top(x)$表示节点$s$所在重链的顶部节点(深度最小)
  • $id(x)$表示节点$x$在线段树中的编号
  • $rk(x)$表示线段树中标号为$x$的节点对应的树上节点的编号

1、第一次DFS,对于一个点求出它所在的子树的大小、它的重儿子,顺便记录其父节点和深度。

 1 void dfs1(int u, int fa, int depth)  //当前节点、父节点、层次深度
 2 {
 3     //printf("u:%d fa:%d depth:%d\n", u, fa, depth);
 4     f[u] = fa;
 5     deep[u] = depth;
 6     size[u] = 1;   //这个点本身的size
 7     for(int i = head[u];i;i = edges[i].next)
 8     {
 9         int v = edges[i].to;
10         if(v == fa)  continue;
11         dfs1(v, u, depth+1);
12         size[u] += size[v];   //子节点的size已被处理,用它来更新父节点的size
13         if(size[v] > size[son[u]])  son[u] = v;    //选取size最大的作为重儿子
14     }
15 }

2、第二次DFS,连接重链,同时标记每个节点的DFS序。为了用数据结构来维护重链,我们在DFS时保证一条重链上的节点DFS序连续。一个节点的子树内DFS序也连续。

 1 void dfs2(int u, int t)  //当前节点、重链顶端
 2 {
 3     printf("u:%d t:%d\n", u, t);
 4     top[u] = t;
 5     id[u] = ++cnt;   //标记dfs序
 6     rk[cnt] = u;     //序号cnt对应节点u
 7     if(!son[u])  return;   //没有儿子?
 8     dfs2(son[u], t);  //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续
 9 
10     for(int i = head[u];i;i = edges[i].next)
11     {
12         int v = edges[i].to;
13         if(v != son[u] && v != f[u])  dfs2(v, v);  //这个点位于轻链顶端,那么它的top必然为它本身
14     }
15 }

3、两遍DFS就是树链剖分的主要处理,通过dfs我们已经保证一条重链上各个节点的dfs序连续,那么可以想到,我们可以通过数据结构来维护(以线段树为例)来维护一条重链的信息。

维护和

 1 ll querysum(int x, int y)
 2 {
 3     int fx = top[x], fy = top[y];
 4     ll ans = 0;
 5     while(fx != fy)   //当两者不在同一条重链上
 6     {
 7         if(deep[fx] >= deep[fy])
 8         {
 9             ans += st.query2(1, 1, n, 0, id[fx], id[x]);   //线段树区间求和,计算这条重链的贡献
10             x = f[fx]; fx = top[x];
11         }
12         else
13         {
14             ans += st.query2(1, 1, n, 0, id[fy], id[y]);
15             y = f[fy]; fy = top[y];
16         }
17     }
18 
19     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
20     if(id[x] <= id[y])
21     {
22         ans += st.query2(1, 1, n, 0, id[x], id[y]);
23     }
24     else
25     {
26         ans += st.query2(1, 1, n, 0, id[y], id[x]);
27     }
28     return ans;
29 }

维护最大值

 1 ll querymax(int x, int y)
 2 {
 3     int fx = top[x], fy = top[y];
 4     ll ans = -INF;
 5     while(fx != fy)   //当两者不在同一条重链上
 6     {
 7         if(deep[fx] >= deep[fy])
 8         {
 9             ans = max(ans, st.query1(1, 1, n, 0, id[fx], id[x]));   //线段树区间求和,计算这条重链的贡献
10             x = f[fx]; fx = top[x];
11         }
12         else
13         {
14             ans = max(ans, st.query1(1, 1, n, 0, id[fy], id[y]));
15             y = f[fy]; fy = top[y];
16         }
17     }
18     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
19     if(id[x] <= id[y])  ans = max(ans, st.query1(1, 1, n, 0, id[x], id[y]));
20     else ans = max(ans, st.query1(1, 1, n, 0, id[y], id[x]));
21     return ans;
22 }

时间复杂度

对于每次询问,最多经过$O(log\ n)$条重链,每条重链上线段树的复杂度为$O(log \ n)$,此时总的时间复杂度为$O(nlogn+q{log}^2n)$。实际上重链个数很难达到$O(log \ n)$(可以用完全二叉树卡满),所以树剖在一般情况下常数较小。

完整的代码

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 
  4 typedef long long ll;
  5 #define lc o <<1
  6 #define rc o <<1 | 1
  7 const int INF = 0x3f3f3f3f;
  8 const int maxn = 100000 + 10;
  9 struct Edge
 10 {
 11     int to, next;
 12 }edges[2*maxn];
 13 int head[maxn];
 14 int cur, f[maxn], deep[maxn], size[maxn], son[maxn], rk[maxn], id[maxn], top[maxn], cnt;
 15 int n, root, qcnt, w[maxn];
 16 
 17 inline void addedge(int u, int v)
 18 {
 19     ++cur;
 20     edges[cur].next = head[u];
 21     head[u] = cur;
 22     edges[cur].to = v;
 23 }
 24 
 25 struct SegTree{
 26     ll sum[maxn << 2], maxv[maxn << 2], addv[maxn << 2];
 27     void build(int o, int l, int r)
 28     {
 29         if(l == r)
 30         {
 31             sum[o] = maxv[o] = w[rk[l]];
 32         }
 33         else
 34         {
 35             int mid = (l + r) >> 1;
 36             build(lc, l, mid);
 37             build(rc, mid+1, r);
 38             sum[o] = sum[lc] + sum[rc];
 39             maxv[o] = max(maxv[lc], maxv[rc]);
 40         }
 41     }
 42 
 43     void maintain(int o, int l, int r)
 44     {
 45         if(l == r)  //如果是叶子结点
 46         {
 47             maxv[o] = w[rk[l]];
 48             sum[o] = w[rk[l]];
 49         }
 50         else     //如果是非叶子结点
 51         {
 52             maxv[o] = max(maxv[lc], maxv[rc]);
 53             sum[o] = sum[lc] + sum[rc];
 54         }
 55         maxv[o] += addv[o];     //考虑add操作
 56         sum[o] += addv[o] * (r-l+1);
 57     }
 58     //区间修改,[cl,cr] += v;
 59     void update(int o, int l, int r, int cl, int cr, int v)  //
 60     {
 61         //printf("o:%d  l:%d  r:%d\n", o, l, r);
 62         if(cl <= l && r <= cr)  addv[o] += v;
 63         else
 64         {
 65             int m = l + (r-l) /2;
 66             if(cl <= m)  update(lc, l, m, cl, cr, v);
 67             if(cr > m)  update(rc, m+1, r, cl, cr, v);
 68         }
 69         maintain(o, l, r);
 70     }
 71 
 72     //区间查询1,max{ql,qr}
 73     ll query1(int o, int l,int r, int add, int ql, int qr)
 74     {
 75         //prllf("o:%d l:%d r:%d\n", o, l, r);
 76         if(ql <= l && r <= qr)  return maxv[o] + add;
 77         else
 78         {
 79             int m = l + (r - l) / 2;
 80             ll ans = -INF;
 81             add += addv[o];
 82             if(ql <= m)  ans = max(ans, query1(lc, l, m, add, ql, qr));
 83             if(qr > m)  ans = max(ans, query1(rc, m+1, r, add, ql, qr));
 84             return ans;
 85         }
 86     }
 87 
 88     //区间查询2,sum{ql,qr}
 89     ll query2(int o, int l,int r, int add, int ql, int qr)
 90     {
 91         //prllf("o:%d l:%d r:%d ql:%d qr:%d\n", o, l, r, ql, qr);
 92         if(ql <= l && r <= qr)  return sum[o] + add * (r-l+1);
 93         else
 94         {
 95             int m = l + (r - l) / 2;
 96             ll ans = 0;
 97             add += addv[o];
 98             if(ql <= m)  ans += query2(lc, l, m, add, ql, qr);
 99             if(qr > m)  ans += query2(rc, m+1, r, add, ql, qr);
100             return ans;
101         }
102     }
103 }st;
104 
105 void dfs1(int u, int fa, int depth)  //当前节点、父节点、层次深度
106 {
107     //printf("u:%d fa:%d depth:%d\n", u, fa, depth);
108     f[u] = fa;
109     deep[u] = depth;
110     size[u] = 1;   //这个点本身的size
111     for(int i = head[u];i;i = edges[i].next)
112     {
113         int v = edges[i].to;
114         if(v == fa)  continue;
115         dfs1(v, u, depth+1);
116         size[u] += size[v];   //子节点的size已被处理,用它来更新父节点的size
117         if(size[v] > size[son[u]])  son[u] = v;    //选取size最大的作为重儿子
118     }
119 }
120 
121 void dfs2(int u, int t)  //当前节点、重链顶端
122 {
123     printf("u:%d t:%d\n", u, t);
124     top[u] = t;
125     id[u] = ++cnt;   //标记dfs序
126     rk[cnt] = u;     //序号cnt对应节点u
127     if(!son[u])  return;   //没有儿子?
128     dfs2(son[u], t);  //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续
129 
130     for(int i = head[u];i;i = edges[i].next)
131     {
132         int v = edges[i].to;
133         if(v != son[u] && v != f[u])  dfs2(v, v);  //这个点位于轻链顶端,那么它的top必然为它本身
134     }
135 }
136 
137 ll querymax(int x, int y)
138 {
139     int fx = top[x], fy = top[y];
140     ll ans = -INF;
141     while(fx != fy)   //当两者不在同一条重链上
142     {
143         if(deep[fx] >= deep[fy])
144         {
145             ans = max(ans, st.query1(1, 1, n, 0, id[fx], id[x]));   //线段树区间求和,计算这条重链的贡献
146             x = f[fx]; fx = top[x];
147         }
148         else
149         {
150             ans = max(ans, st.query1(1, 1, n, 0, id[fy], id[y]));
151             y = f[fy]; fy = top[y];
152         }
153     }
154     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
155     if(id[x] <= id[y])  ans = max(ans, st.query1(1, 1, n, 0, id[x], id[y]));
156     else ans = max(ans, st.query1(1, 1, n, 0, id[y], id[x]));
157     return ans;
158 }
159 
160 /*修改和查询的原理是一致的,以查询操作为例,其实就是个LCA,不过这里要使用top数组加速,因为top可以直接跳到该重链的起始顶点*/
161 /*注意,每次循环只能跳一次,并且让结点深的那个跳到top的位置,避免两者一起跳而插肩而过*/
162 ll querysum(int x, int y)
163 {
164     int fx = top[x], fy = top[y];
165     ll ans = 0;
166     while(fx != fy)   //当两者不在同一条重链上
167     {
168         if(deep[fx] >= deep[fy])
169         {
170             ans += st.query2(1, 1, n, 0, id[fx], id[x]);   //线段树区间求和,计算这条重链的贡献
171             x = f[fx]; fx = top[x];
172         }
173         else
174         {
175             ans += st.query2(1, 1, n, 0, id[fy], id[y]);
176             y = f[fy]; fy = top[y];
177         }
178     }
179 
180     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
181     if(id[x] <= id[y])
182     {
183         ans += st.query2(1, 1, n, 0, id[x], id[y]);
184     }
185     else
186     {
187         ans += st.query2(1, 1, n, 0, id[y], id[x]);
188     }
189     return ans;
190 }
191 
192 void update_add(int x, int y, int add)
193 {
194     int fx = top[x], fy = top[y];
195     while(fx != fy)   //当两者不在同一条重链上
196     {
197         if(deep[fx] >= deep[fy])
198         {
199             st.update(1, 1, n, id[fx], id[x], add);
200             x = f[fx]; fx = top[x];
201         }
202         else
203         {
204             st.update(1, 1, n, id[fy], id[y], add);
205             y = f[fy]; fy = top[y];
206         }
207     }
208     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
209     if(id[x] <= id[y])  st.update(1, 1, n, id[x], id[y], add);
210     else  st.update(1, 1, n, id[y], id[x], add);
211 }
212 
213 int main()
214 {
215     scanf("%d%d%d", &n, &root, &qcnt);
216     for(int i = 1;i <= n;i++)  scanf("%d", &w[i]);
217     for(int i = 1;i < n;i++)
218     {
219         int u, v;
220         scanf("%d%d", &u, &v);
221         addedge(u, v);
222         addedge(v, u);
223     }
224     dfs1(root, -1, 1);
225     dfs2(root, root);
226 
227     for(int i = 1;i <= n;i++)  printf("%d  ", id[i]);
228     printf("\n");
229     for(int i = 1;i <= n;i++)  printf("%d  ", rk[i]);
230     printf("\n");
231 
232     st.build(1, 1, n);
233 
234     while(qcnt--)
235     {
236         int op;
237         scanf("%d", &op);
238         if(op == 1)
239         {
240             int u, v, add;
241             scanf("%d%d%d", &u, &v, &add);
242             update_add(u, v,  add);
243         }
244         else if(op == 2)
245         {
246             int u, v;
247             scanf("%d%d", &u, &v);
248             printf("%d\n", querymax(u, v));
249         }
250         else if(op == 3)
251         {
252             int u, v;
253             scanf("%d%d", &u, &v);
254             printf("%d\n", querysum(u, v));
255         }
256         else if(op == 4)
257         {
258             int u, add;
259             scanf("%d%d", &u, &add);
260             st.update(1, 1, n, id[u], id[u]+size[u]-1, add);
261         }
262         else if(op == 5)
263         {
264             int u;
265             scanf("%d", &u);
266             printf("%d\n",st.query1(1, 1, n, 0, id[u], id[u]+size[u]-1));
267         }
268         else
269         {
270             int u;
271             scanf("%d", &u);
272             printf("%d\n",st.query2(1, 1, n, 0, id[u], id[u]+size[u]-1));
273         }
274     }
275     return 0;
276 }
View Code

 

 

参考链接:

1. https://oi-wiki.org/graph/heavy-light-decomposition/

2. https://zhuanlan.zhihu.com/p/41082337

3. https://www.luogu.org/problemnew/solution/P3384

 

posted @ 2019-07-11 12:28  Rogn  阅读(410)  评论(0编辑  收藏  举报