树链剖分模板
概念
树剖就是将一棵树暴力拆成几条链,然后对于这样一个序列,我们就可以套上资瓷区间处理的一些东西qwq(比如说线段树,树状数组
可以解决的问题:
- 将树从$x$到$y$结点最短路径上所有节点的值都加上$z$
- 求树从$x$到$y$结点最短路径上所有节点的值之和/最大值
- 将以$x$为根节点的子树内所有节点值都加上$z$
- 求以$x$为根节点的子树内所有节点值之和/最大值
一些概念:
- 重儿子:父亲节点的所有儿子中子树结点数目最多($size$最大)的结点;
- 轻儿子:父亲节点中除了重儿子以外的儿子;
- 重边:父亲结点和重儿子连成的边;
- 轻边:父亲节点和轻儿子连成的边;
- 重链:由多条重边连接而成的路径;
- 轻链:由多条轻边连接而成的路径;
实现
一些定义:
- $f(x)$表示节点$x$在树上的父亲
- $dep(x)$表示节点$x$在树上的深度
- $siz(x)$表示节点$x$的子树的节点的个数
- $son(x)$表示节点$x$的重儿子
- $top(x)$表示节点$s$所在重链的顶部节点(深度最小)
- $id(x)$表示节点$x$在线段树中的编号
- $rk(x)$表示线段树中标号为$x$的节点对应的树上节点的编号
1、第一次DFS,对于一个点求出它所在的子树的大小、它的重儿子,顺便记录其父节点和深度。
1 void dfs1(int u, int fa, int depth) //当前节点、父节点、层次深度 2 { 3 //printf("u:%d fa:%d depth:%d\n", u, fa, depth); 4 f[u] = fa; 5 deep[u] = depth; 6 size[u] = 1; //这个点本身的size 7 for(int i = head[u];i;i = edges[i].next) 8 { 9 int v = edges[i].to; 10 if(v == fa) continue; 11 dfs1(v, u, depth+1); 12 size[u] += size[v]; //子节点的size已被处理,用它来更新父节点的size 13 if(size[v] > size[son[u]]) son[u] = v; //选取size最大的作为重儿子 14 } 15 }
2、第二次DFS,连接重链,同时标记每个节点的DFS序。为了用数据结构来维护重链,我们在DFS时保证一条重链上的节点DFS序连续。一个节点的子树内DFS序也连续。
1 void dfs2(int u, int t) //当前节点、重链顶端 2 { 3 printf("u:%d t:%d\n", u, t); 4 top[u] = t; 5 id[u] = ++cnt; //标记dfs序 6 rk[cnt] = u; //序号cnt对应节点u 7 if(!son[u]) return; //没有儿子? 8 dfs2(son[u], t); //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续 9 10 for(int i = head[u];i;i = edges[i].next) 11 { 12 int v = edges[i].to; 13 if(v != son[u] && v != f[u]) dfs2(v, v); //这个点位于轻链顶端,那么它的top必然为它本身 14 } 15 }
3、两遍DFS就是树链剖分的主要处理,通过dfs我们已经保证一条重链上各个节点的dfs序连续,那么可以想到,我们可以通过数据结构来维护(以线段树为例)来维护一条重链的信息。
维护和
1 ll querysum(int x, int y) 2 { 3 int fx = top[x], fy = top[y]; 4 ll ans = 0; 5 while(fx != fy) //当两者不在同一条重链上 6 { 7 if(deep[fx] >= deep[fy]) 8 { 9 ans += st.query2(1, 1, n, 0, id[fx], id[x]); //线段树区间求和,计算这条重链的贡献 10 x = f[fx]; fx = top[x]; 11 } 12 else 13 { 14 ans += st.query2(1, 1, n, 0, id[fy], id[y]); 15 y = f[fy]; fy = top[y]; 16 } 17 } 18 19 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 20 if(id[x] <= id[y]) 21 { 22 ans += st.query2(1, 1, n, 0, id[x], id[y]); 23 } 24 else 25 { 26 ans += st.query2(1, 1, n, 0, id[y], id[x]); 27 } 28 return ans; 29 }
维护最大值
1 ll querymax(int x, int y) 2 { 3 int fx = top[x], fy = top[y]; 4 ll ans = -INF; 5 while(fx != fy) //当两者不在同一条重链上 6 { 7 if(deep[fx] >= deep[fy]) 8 { 9 ans = max(ans, st.query1(1, 1, n, 0, id[fx], id[x])); //线段树区间求和,计算这条重链的贡献 10 x = f[fx]; fx = top[x]; 11 } 12 else 13 { 14 ans = max(ans, st.query1(1, 1, n, 0, id[fy], id[y])); 15 y = f[fy]; fy = top[y]; 16 } 17 } 18 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 19 if(id[x] <= id[y]) ans = max(ans, st.query1(1, 1, n, 0, id[x], id[y])); 20 else ans = max(ans, st.query1(1, 1, n, 0, id[y], id[x])); 21 return ans; 22 }
时间复杂度
对于每次询问,最多经过$O(log\ n)$条重链,每条重链上线段树的复杂度为$O(log \ n)$,此时总的时间复杂度为$O(nlogn+q{log}^2n)$。实际上重链个数很难达到$O(log \ n)$(可以用完全二叉树卡满),所以树剖在一般情况下常数较小。
完整的代码
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 typedef long long ll; 5 #define lc o <<1 6 #define rc o <<1 | 1 7 const int INF = 0x3f3f3f3f; 8 const int maxn = 100000 + 10; 9 struct Edge 10 { 11 int to, next; 12 }edges[2*maxn]; 13 int head[maxn]; 14 int cur, f[maxn], deep[maxn], size[maxn], son[maxn], rk[maxn], id[maxn], top[maxn], cnt; 15 int n, root, qcnt, w[maxn]; 16 17 inline void addedge(int u, int v) 18 { 19 ++cur; 20 edges[cur].next = head[u]; 21 head[u] = cur; 22 edges[cur].to = v; 23 } 24 25 struct SegTree{ 26 ll sum[maxn << 2], maxv[maxn << 2], addv[maxn << 2]; 27 void build(int o, int l, int r) 28 { 29 if(l == r) 30 { 31 sum[o] = maxv[o] = w[rk[l]]; 32 } 33 else 34 { 35 int mid = (l + r) >> 1; 36 build(lc, l, mid); 37 build(rc, mid+1, r); 38 sum[o] = sum[lc] + sum[rc]; 39 maxv[o] = max(maxv[lc], maxv[rc]); 40 } 41 } 42 43 void maintain(int o, int l, int r) 44 { 45 if(l == r) //如果是叶子结点 46 { 47 maxv[o] = w[rk[l]]; 48 sum[o] = w[rk[l]]; 49 } 50 else //如果是非叶子结点 51 { 52 maxv[o] = max(maxv[lc], maxv[rc]); 53 sum[o] = sum[lc] + sum[rc]; 54 } 55 maxv[o] += addv[o]; //考虑add操作 56 sum[o] += addv[o] * (r-l+1); 57 } 58 //区间修改,[cl,cr] += v; 59 void update(int o, int l, int r, int cl, int cr, int v) // 60 { 61 //printf("o:%d l:%d r:%d\n", o, l, r); 62 if(cl <= l && r <= cr) addv[o] += v; 63 else 64 { 65 int m = l + (r-l) /2; 66 if(cl <= m) update(lc, l, m, cl, cr, v); 67 if(cr > m) update(rc, m+1, r, cl, cr, v); 68 } 69 maintain(o, l, r); 70 } 71 72 //区间查询1,max{ql,qr} 73 ll query1(int o, int l,int r, int add, int ql, int qr) 74 { 75 //prllf("o:%d l:%d r:%d\n", o, l, r); 76 if(ql <= l && r <= qr) return maxv[o] + add; 77 else 78 { 79 int m = l + (r - l) / 2; 80 ll ans = -INF; 81 add += addv[o]; 82 if(ql <= m) ans = max(ans, query1(lc, l, m, add, ql, qr)); 83 if(qr > m) ans = max(ans, query1(rc, m+1, r, add, ql, qr)); 84 return ans; 85 } 86 } 87 88 //区间查询2,sum{ql,qr} 89 ll query2(int o, int l,int r, int add, int ql, int qr) 90 { 91 //prllf("o:%d l:%d r:%d ql:%d qr:%d\n", o, l, r, ql, qr); 92 if(ql <= l && r <= qr) return sum[o] + add * (r-l+1); 93 else 94 { 95 int m = l + (r - l) / 2; 96 ll ans = 0; 97 add += addv[o]; 98 if(ql <= m) ans += query2(lc, l, m, add, ql, qr); 99 if(qr > m) ans += query2(rc, m+1, r, add, ql, qr); 100 return ans; 101 } 102 } 103 }st; 104 105 void dfs1(int u, int fa, int depth) //当前节点、父节点、层次深度 106 { 107 //printf("u:%d fa:%d depth:%d\n", u, fa, depth); 108 f[u] = fa; 109 deep[u] = depth; 110 size[u] = 1; //这个点本身的size 111 for(int i = head[u];i;i = edges[i].next) 112 { 113 int v = edges[i].to; 114 if(v == fa) continue; 115 dfs1(v, u, depth+1); 116 size[u] += size[v]; //子节点的size已被处理,用它来更新父节点的size 117 if(size[v] > size[son[u]]) son[u] = v; //选取size最大的作为重儿子 118 } 119 } 120 121 void dfs2(int u, int t) //当前节点、重链顶端 122 { 123 printf("u:%d t:%d\n", u, t); 124 top[u] = t; 125 id[u] = ++cnt; //标记dfs序 126 rk[cnt] = u; //序号cnt对应节点u 127 if(!son[u]) return; //没有儿子? 128 dfs2(son[u], t); //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续 129 130 for(int i = head[u];i;i = edges[i].next) 131 { 132 int v = edges[i].to; 133 if(v != son[u] && v != f[u]) dfs2(v, v); //这个点位于轻链顶端,那么它的top必然为它本身 134 } 135 } 136 137 ll querymax(int x, int y) 138 { 139 int fx = top[x], fy = top[y]; 140 ll ans = -INF; 141 while(fx != fy) //当两者不在同一条重链上 142 { 143 if(deep[fx] >= deep[fy]) 144 { 145 ans = max(ans, st.query1(1, 1, n, 0, id[fx], id[x])); //线段树区间求和,计算这条重链的贡献 146 x = f[fx]; fx = top[x]; 147 } 148 else 149 { 150 ans = max(ans, st.query1(1, 1, n, 0, id[fy], id[y])); 151 y = f[fy]; fy = top[y]; 152 } 153 } 154 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 155 if(id[x] <= id[y]) ans = max(ans, st.query1(1, 1, n, 0, id[x], id[y])); 156 else ans = max(ans, st.query1(1, 1, n, 0, id[y], id[x])); 157 return ans; 158 } 159 160 /*修改和查询的原理是一致的,以查询操作为例,其实就是个LCA,不过这里要使用top数组加速,因为top可以直接跳到该重链的起始顶点*/ 161 /*注意,每次循环只能跳一次,并且让结点深的那个跳到top的位置,避免两者一起跳而插肩而过*/ 162 ll querysum(int x, int y) 163 { 164 int fx = top[x], fy = top[y]; 165 ll ans = 0; 166 while(fx != fy) //当两者不在同一条重链上 167 { 168 if(deep[fx] >= deep[fy]) 169 { 170 ans += st.query2(1, 1, n, 0, id[fx], id[x]); //线段树区间求和,计算这条重链的贡献 171 x = f[fx]; fx = top[x]; 172 } 173 else 174 { 175 ans += st.query2(1, 1, n, 0, id[fy], id[y]); 176 y = f[fy]; fy = top[y]; 177 } 178 } 179 180 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 181 if(id[x] <= id[y]) 182 { 183 ans += st.query2(1, 1, n, 0, id[x], id[y]); 184 } 185 else 186 { 187 ans += st.query2(1, 1, n, 0, id[y], id[x]); 188 } 189 return ans; 190 } 191 192 void update_add(int x, int y, int add) 193 { 194 int fx = top[x], fy = top[y]; 195 while(fx != fy) //当两者不在同一条重链上 196 { 197 if(deep[fx] >= deep[fy]) 198 { 199 st.update(1, 1, n, id[fx], id[x], add); 200 x = f[fx]; fx = top[x]; 201 } 202 else 203 { 204 st.update(1, 1, n, id[fy], id[y], add); 205 y = f[fy]; fy = top[y]; 206 } 207 } 208 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 209 if(id[x] <= id[y]) st.update(1, 1, n, id[x], id[y], add); 210 else st.update(1, 1, n, id[y], id[x], add); 211 } 212 213 int main() 214 { 215 scanf("%d%d%d", &n, &root, &qcnt); 216 for(int i = 1;i <= n;i++) scanf("%d", &w[i]); 217 for(int i = 1;i < n;i++) 218 { 219 int u, v; 220 scanf("%d%d", &u, &v); 221 addedge(u, v); 222 addedge(v, u); 223 } 224 dfs1(root, -1, 1); 225 dfs2(root, root); 226 227 for(int i = 1;i <= n;i++) printf("%d ", id[i]); 228 printf("\n"); 229 for(int i = 1;i <= n;i++) printf("%d ", rk[i]); 230 printf("\n"); 231 232 st.build(1, 1, n); 233 234 while(qcnt--) 235 { 236 int op; 237 scanf("%d", &op); 238 if(op == 1) 239 { 240 int u, v, add; 241 scanf("%d%d%d", &u, &v, &add); 242 update_add(u, v, add); 243 } 244 else if(op == 2) 245 { 246 int u, v; 247 scanf("%d%d", &u, &v); 248 printf("%d\n", querymax(u, v)); 249 } 250 else if(op == 3) 251 { 252 int u, v; 253 scanf("%d%d", &u, &v); 254 printf("%d\n", querysum(u, v)); 255 } 256 else if(op == 4) 257 { 258 int u, add; 259 scanf("%d%d", &u, &add); 260 st.update(1, 1, n, id[u], id[u]+size[u]-1, add); 261 } 262 else if(op == 5) 263 { 264 int u; 265 scanf("%d", &u); 266 printf("%d\n",st.query1(1, 1, n, 0, id[u], id[u]+size[u]-1)); 267 } 268 else 269 { 270 int u; 271 scanf("%d", &u); 272 printf("%d\n",st.query2(1, 1, n, 0, id[u], id[u]+size[u]-1)); 273 } 274 } 275 return 0; 276 }
参考链接:
1. https://oi-wiki.org/graph/heavy-light-decomposition/
2. https://zhuanlan.zhihu.com/p/41082337
3. https://www.luogu.org/problemnew/solution/P3384
个性签名:时间会解决一切