树链剖分整理
树链剖分整理总结
问题的设置:
对于一棵树(无向无环连通图),为每个结点分配对应的权重。要求能高效计算任意两个结点之间的路径的各类信息,其中包括路径长度(路径上所有结点的权重加总),路径中最大权重,最小权重等等。到这里一切都还是比较简单的,我们可以利用 Tarjan 的 LCA 算法在线性时间复杂度内快速求解。但是如果还要求允许动态修改任意结点的权值,那么问题就不简单了。很容易发现这有些类似于线段树的问题,但是线段树是作用在区间上,而我们的问题是发生在树上的,因此线段树并不适用。而树链剖分则是可以用于求解这类问题的高效算法,其更新权值的时间复杂度为 $O(logn)$ 而统计路径信息的时间复杂度为 $O(log^2n)$。
树链剖分通常用于处理树的形态不变 但点权/边权需要修改查询的题目。
理解:
树链剖分就是将树分割成多条互不相交的轻重链,然后利用数据结构(线段树、树状数组等)来维护这些链。
首先就是一些必须知道的概念:
- 重节点:子树结点数目最多的节点;
- 轻节点:除了重结点以外的节点;
- 重边:父亲节点和重节点连成的边;
- 轻边:父亲节点和轻节点连成的边;
- 重链:由多条重边连接而成的路径;
- 轻链:由多条轻边连接而成的路径;
比如上面这幅图中,粗黑色的边就是重边,连续的重边连起来就是重链,1,2,3,5,12,8,10 为轻节点,其余为重节点,用红点标记的就是该结点所在链的起点,还有每条边的值其实是进行 dfs 时的执行序号。理解树链剖分算法其实不用管轻节点和轻链,把底端的轻节点看成由本身一个节点形成的重链,这样所有节点都在重链上。
具体实现:
1. 一条重链在线段树上是一段连续的区间。
2. 一个节点的重儿子就是这个节点的子节点中子树最大的点。
3. 记录链需要记录链最顶上的点,记作 top。具体实现,是基于两次 dfs。第一次 dfs 求出重儿子。第二次 dfs,通过节点映射时间戳,优先 dfs 重儿子,得出节点新编号,就可以得到上述需要满足的点。
时间复杂度:
性质1:如果 $(u, v)$ 是一条轻边,那么 $size(v) < size(u)/2$;
证明:由于任一轻儿子对应的子树大小要小于父节点所对应子树大小的一半,因此从一个轻儿子沿轻边向上走到父节点后 所对应的子树大小至少变为两倍以上;
性质2:从任一节点向根节点走,走过的重链和轻边的条数都是 $logn$ 级别以内的;
证明:基于性质1,显然经过的轻边条数是不超过 $logn$ 的,然后由于重链都是间断的 (连续的可以合成一条),所以经过的重链的条数是不超过轻边条数 $+1$ 的,因此经过重链的条数也是 $logn$ 级别的;
通过链的划分,就可以满足在求 lca 的时候,因为基于线段树,处理节点到所在重链顶端的数据的复杂度就是 $O(logn)$,又由性质2,则统计或修改路径信息的时间复杂度为 $O(log^2n)$,修改单点权值时间复杂度为 $O(logn)$;
树链的建立只涉及到了深度优先搜索,时间复杂度为 $O(n)$,而线段树在知道初始值时可以以 $O(n)$ 时间复杂度建立,因此预处理的时间复杂度为 $O(n)$。
所以总的时间复杂度为 $O(n + Qlog^2n)$。( $Q$ 为询问和操作次数)。
实战
1. 洛谷P3384【模板】树链剖分 (基于点权,路径、子树更新、查询)
树链剖分+线段树模板代码:
1 #include <stdio.h> 2 #include <iostream> 3 #define REP(i, a, b) for (int i = a; i <= b; i++) 4 using namespace std; 5 typedef long long LL; 6 const int MAXN = 110000; 7 struct Node { 8 int to, next; 9 } edg[MAXN<<1]; 10 struct segmentTree { 11 int left, right; 12 LL sum, tag; 13 } tree[MAXN<<2]; 14 int head[MAXN], siz[MAXN], top[MAXN], hson[MAXN], dep[MAXN], fa[MAXN], id[MAXN], rnk[MAXN]; 15 /* 16 head[u]:前向星储存边中保存节点u最后添加的边的编号 17 siz[u]:保存以u为根的子树的节点个数 18 top[u]:保存当前节点所在链的顶端节点 19 hson[u]:保存节点u的重儿子 20 dep[u]:保存节点u的深度值 21 fa[u]:保存节点u的父亲节点编号 22 id[u]:保存节点u剖分以后的新编号(dfs2执行顺序) 23 rnk[u]:u为新编号,rnk[u]为原编号 24 */ 25 int N, M, R, A[MAXN], idx = 0, dfs_cnt = 0; 26 LL mod; 27 inline int read() { // 读入优化 28 int x = 0, f = 1; char ch = getchar(); 29 while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); } 30 while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); } 31 return x * f; 32 } 33 inline void adde(int u, int v) { // 链式前向星添加u向v的边 34 edg[++idx].to = v; edg[idx].next = head[u]; head[u] = idx; 35 } 36 /* 第一次DFS可以得到当前节点的父亲结点(fa数组)、 37 当前结点的深度值(dep数组)、当前结点的子结点数量(siz数组)、 38 当前结点的重结点(hson数组) */ 39 void dfs1(int u, int father, int depth) { 40 dep[u] = depth; 41 fa[u] = father; 42 siz[u] = 1; 43 for (int i = head[u]; i; i = edg[i].next) { 44 int v = edg[i].to; 45 if (v != fa[u]) { 46 dfs1(v, u, depth + 1); 47 siz[u] += siz[v]; 48 if (hson[u] == -1 || siz[v] > siz[hson[u]]) hson[u] = v; 49 } 50 } 51 } 52 /* 第二次DFS的时候则可以将各个重节点连接成重链,轻节点连接成轻链, 53 并且将重链(其实就是一段区间)用数据结构(一般是树状数组或线段树) 54 来进行维护,为每个节点进行编号,其实就是DFS在执行时的顺序(id数组), 55 以及当前节点所在链的起点(top数组),还有当前节点在树中的位置(rnk数组)。 */ 56 void dfs2(int u, int t) { 57 id[u] = ++dfs_cnt; rnk[dfs_cnt] = u; top[u] = t; 58 if (!hson[u]) return ; 59 dfs2(hson[u], t); 60 for (int i = head[u]; i; i = edg[i].next) { 61 int v = edg[i].to; 62 if (v != hson[u] && v != fa[u]) dfs2(v, v); 63 } 64 } 65 void buildtree(int i, int l, int r) { // 建线段树 66 tree[i].left = l; tree[i].right = r; 67 // 注意线段树维护的节点的编号是dfs2后的新编号,不是原节点编号 68 // 因此此处通过rnk[l]得出原节点编号 69 if (l == r) tree[i].sum = A[rnk[l]]; 70 else { 71 int mid = (l + r) >> 1; 72 buildtree(i << 1, l , mid); 73 buildtree(i << 1 | 1, mid + 1, r); 74 tree[i].sum = tree[i<<1].sum + tree[i<<1|1].sum; 75 } 76 } 77 void pushdown(int i) { // 线段树上i节点向下更新一层 78 int l = i << 1, r = i << 1 | 1; 79 tree[l].sum += (tree[l].right - tree[l].left + 1) * tree[i].tag; 80 tree[r].sum += (tree[r].right - tree[r].left + 1) * tree[i].tag; 81 tree[l].tag += tree[i].tag; 82 tree[r].tag += tree[i].tag; 83 tree[i].tag = 0; 84 } 85 void update(int i, int x, int y, LL z) { // 当前线段树上节点为i,在区间[x, y]上成段增加z 86 int l = i << 1, r = i << 1 | 1; 87 if (tree[i].left > y || tree[i].right < x) return ; 88 if (x <= tree[i].left && tree[i].right <= y) { 89 tree[i].sum += (tree[i].right - tree[i].left + 1) * z; 90 tree[i].tag += z; 91 } else { 92 if (tree[i].tag) pushdown(i); 93 update(l, x, y, z); 94 update(r, x, y, z); 95 tree[i].sum = tree[l].sum + tree[r].sum; 96 } 97 } 98 LL query(int i, int x, int y) { // 当前线段树上节点为i,查询区间[x, y]的和 99 int l = i << 1, r = i << 1 | 1; 100 if (x <= tree[i].left && tree[i].right <= y) return tree[i].sum; 101 if (tree[i].left > y || tree[i].right < x) return 0; 102 if (tree[i].tag) pushdown(i); 103 return query(l, x, y) + query(r, x, y); 104 } 105 void update_path(int u, int v, LL z) { // 从u到v节点最短路径上所有节点的值都加上z 106 int tu = top[u], tv = top[v]; 107 while (tu != tv) { 108 if (dep[tu] < dep[tv]) swap(u, v), swap(tu, tv); 109 update(1, id[tu], id[u], z); 110 u = fa[tu], tu = top[u]; 111 } 112 if (dep[u] < dep[v]) swap(u, v); 113 update(1, id[v], id[u], z); 114 } 115 LL query_path(int u, int v) { // 求树从u到v结点最短路径上所有节点的值之和 116 LL res = 0; 117 int tu = top[u], tv = top[v]; 118 // 计算节点到所在重链起始点的路径和直到两个节点在同一条链上 119 // 注意每次循环只能跳一次,并且让结点深的那个来跳到top的位置,避免两个一起跳从而插肩而过。 120 while (tu != tv) { 121 if (dep[tu] < dep[tv]) swap(u, v), swap(tu, tv); 122 res += query(1, id[tu], id[u]); 123 u = fa[tu], tu = top[u]; 124 } 125 // 即使两个节点在同一条链上,仍需要再计算一次路径和,因为最后一次更新u,v后是没有计算的 126 if (dep[u] < dep[v]) swap(u, v); 127 return res + query(1, id[v], id[u]); 128 } 129 130 int main() { 131 N = read(), M = read(), R = read(), mod = read(); 132 REP(i, 1, N) A[i] = read(); 133 REP(i, 2, N) { 134 int u = read(), v = read(); 135 adde(u, v); adde(v, u); 136 } 137 dfs1(R, 0, 1); 138 dfs2(R, R); 139 buildtree(1, 1, N); 140 while (M--) { 141 int opt = read(); 142 switch (opt) { 143 case 1: { 144 int x = read(), y = read(); 145 LL z; scanf("%lld", &z); 146 update_path(x, y, z); 147 break; 148 } 149 case 2: { 150 int x = read(), y = read(); 151 printf("%lld\n", query_path(x, y) % mod); 152 break; 153 } 154 case 3: { 155 int x = read(); 156 LL z; scanf("%lld", &z); 157 update(1, id[x], id[x] + siz[x] - 1, z); 158 break; 159 } 160 case 4: { 161 int x = read(); 162 printf("%lld\n", query(1, id[x], id[x] + siz[x] - 1) % mod); 163 } 164 } 165 } 166 return 0; 167 }
2. POJ 2763 -- Housewife Wind (基于边权,单点更新,区间查询)
基于边权的树链剖分,对于边来说,除了根节点外,其他所有的边都可以与其子节点一一对应,这样就转换为了点权问题,在处理时要注意LCA当前点不统计,因为此时的边权不会被访问到。
树链剖分+树状数组代码:
1 #include <iostream> 2 #include <stdio.h> 3 #include <cstring> 4 #define MAXN 100010 5 using namespace std; 6 typedef long long LL; 7 struct edg{ 8 int to, next, val; 9 } edg[MAXN<<1]; 10 int idx, tol, head[MAXN], cost[MAXN], c[MAXN], xx[MAXN], yy[MAXN]; 11 // cost映射树上原编号的节点值,c为树状数组 12 int hson[MAXN], fa[MAXN], id[MAXN], dep[MAXN], siz[MAXN], top[MAXN]; 13 int n, q, s; 14 // 前向星添加u向v的边 15 void add_edg(int u, int v, int z) { 16 edg[++idx].to = v; 17 edg[idx].val = z; 18 edg[idx].next = head[u]; 19 head[u] = idx; 20 } 21 // 取x的二进制中1的最低位代表的值 22 int lowbit(int x) { 23 return x & -x; 24 } 25 // 树状数组中向原树上编号为x(新编号)的节点加上值val 26 void update(int x, int val) { 27 for (; x <= n; x += lowbit(x)) c[x] += val; 28 } 29 // 树状数组查询原树上[1, x]的节点值之和 30 LL query(int x) { 31 if (!x) return 0; 32 LL sum = 0; 33 for (; x > 0; x -= lowbit(x)) sum += c[x]; 34 return sum; 35 } 36 // 查询u、v之间最短路径边权和(u, v为树上原编号) 37 LL query_path(int u, int v) { 38 LL ans = 0; 39 int tu = top[u], tv = top[v]; 40 while (tu != tv) { 41 if (dep[tu] < dep[tv]) swap(tu, tv), swap(u, v); 42 ans += query(id[u]) - query(id[tu] - 1); 43 u = fa[tu]; 44 tu = top[u]; 45 } 46 if (dep[u] < dep[v]) swap(u, v); 47 ans += query(id[u]) - query(id[v]); 48 return ans; 49 } 50 void dfs1(int u, int father, int depth) { 51 dep[u] = depth; 52 fa[u] = father; 53 siz[u]++; 54 for (int i = head[u]; ~i; i = edg[i].next) { 55 int v = edg[i].to; 56 if (v != fa[u]) { 57 cost[v] = edg[i].val; 58 dfs1(v, u, depth + 1); 59 siz[u] += siz[v]; 60 if (hson[u] == -1 || siz[hson[u]] < siz[v]) hson[u] = v; 61 } 62 } 63 } 64 void dfs2(int u, int t) { 65 id[u] = ++tol; 66 update(id[u], cost[u]); 67 top[u] = t; 68 if (hson[u] == -1) return ; 69 dfs2(hson[u], t); 70 for (int i = head[u]; ~i; i = edg[i].next) { 71 int v = edg[i].to; 72 if (v != fa[u] && v != hson[u]) dfs2(v, v); 73 } 74 } 75 void init() { 76 cost[1] = cost[0] = idx = tol = 0; 77 memset(siz, 0, sizeof(siz)); 78 memset(top, -1, sizeof(top)); 79 memset(hson, -1, sizeof(hson)); 80 memset(c, 0, sizeof(c)); 81 memset(head, -1, sizeof(head)); 82 } 83 84 int main() { 85 while (~scanf("%d %d %d", &n, &q, &s)) { 86 init(); 87 for (int i = 1; i < n; i++) { 88 int z; 89 scanf("%d %d %d", &xx[i], &yy[i], &z); 90 add_edg(xx[i], yy[i], z); 91 add_edg(yy[i], xx[i], z); 92 } 93 dfs1(1, -1, 1); 94 dfs2(1, 1); 95 while (q--) { 96 int op; 97 scanf("%d", &op); 98 if (op == 0) { 99 int to; 100 scanf("%d", &to); 101 printf("%lld\n", query_path(s, to)); 102 s = to; 103 } else { 104 int ith, newv; 105 scanf("%d %d", &ith, &newv); 106 if (dep[xx[ith]] < dep[yy[ith]]) swap(xx[ith], yy[ith]); 107 update(id[xx[ith]], newv - (query(id[xx[ith]]) - query(id[xx[ith]] - 1))); 108 } 109 } 110 } 111 return 0; 112 }
3.POJ 3237 -- Tree (基于边权,区间标记)
树链剖分+线段树代码:
1 #include <stdio.h> 2 #include <iostream> 3 #include <string> 4 #include <cstring> 5 #define REP(i, a, b) for (int i = a; i <= b; i++) 6 #define INF 0x3f3f3f3f 7 using namespace std; 8 typedef long long LL; 9 const int MAXN = 10010; 10 struct Node { 11 int to, next, val; 12 } edg[MAXN<<1]; 13 // 线段树维护区间最大最小值,tag为取相反数标记 14 struct segmentTree { 15 int left, right; 16 int maxx, minn; 17 int tag; 18 } tree[MAXN<<2]; 19 int head[MAXN], siz[MAXN], top[MAXN], hson[MAXN], dep[MAXN], fa[MAXN], id[MAXN]; 20 int n, xx[MAXN], yy[MAXN], val1[MAXN], val2[MAXN], idx, dfs_cnt; 21 // val1映射树上原编号的节点值,val2映射dfs2后新编号的节点值 22 LL mod; 23 // 读入优化 24 inline int read() { 25 int x = 0, f = 1; char ch = getchar(); 26 while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); } 27 while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); } 28 return x * f; 29 } 30 void init() { 31 val1[0] = val1[1] = idx = dfs_cnt = 0; 32 memset(head, -1, sizeof(head)); 33 memset(hson, -1, sizeof(hson)); 34 memset(siz, 0, sizeof(siz)); 35 } 36 // 前向星添加u向v的边 37 inline void add_edge(int u, int v, int z) { 38 edg[++idx].to = v; 39 edg[idx].val = z; 40 edg[idx].next = head[u]; 41 head[u] = idx; 42 } 43 void dfs1(int u, int father, int depth) { 44 dep[u] = depth; 45 fa[u] = father; 46 siz[u] = 1; 47 for (int i = head[u]; ~i; i = edg[i].next) { 48 int v = edg[i].to; 49 if (v != fa[u]) { 50 val1[v] = edg[i].val; 51 dfs1(v, u, depth + 1); 52 siz[u] += siz[v]; 53 if (hson[u] == -1 || siz[v] > siz[hson[u]]) hson[u] = v; 54 } 55 } 56 } 57 void dfs2(int u, int t) { 58 id[u] = ++dfs_cnt; top[u] = t, val2[id[u]] = val1[u]; 59 if (!~hson[u]) return ; 60 dfs2(hson[u], t); 61 for (int i = head[u]; ~i; i = edg[i].next) { 62 int v = edg[i].to; 63 if (v != hson[u] && v != fa[u]) dfs2(v, v); 64 } 65 } 66 // 更新线段树编号为k的节点的最大最小值 67 void pushup(int k) { 68 tree[k].maxx = max(tree[k<<1].maxx, tree[k<<1|1].maxx); 69 tree[k].minn = min(tree[k<<1].minn, tree[k<<1|1].minn); 70 } 71 // 下放取相反数标记,更新k的左右儿子的最大最小值 72 void pushdown(int k) { 73 if (!tree[k].tag) return ; 74 tree[k<<1].tag ^= 1, tree[k<<1|1].tag ^= 1; 75 swap(tree[k<<1].minn, tree[k<<1].maxx); 76 tree[k<<1].maxx = -tree[k<<1].maxx; 77 tree[k<<1].minn = -tree[k<<1].minn; 78 swap(tree[k<<1|1].minn, tree[k<<1|1].maxx); 79 tree[k<<1|1].maxx = -tree[k<<1|1].maxx; 80 tree[k<<1|1].minn = -tree[k<<1|1].minn; 81 tree[k].tag = 0; 82 } 83 // 建线段树 84 void build_tree(int i, int l, int r) { 85 tree[i].left = l; tree[i].right = r, tree[i].tag = 0; 86 if (l == r) tree[i].minn = tree[i].maxx = val2[l]; 87 else { 88 int mid = (l + r) >> 1; 89 build_tree(i << 1, l , mid); 90 build_tree(i << 1 | 1, mid + 1, r); 91 pushup(i); 92 } 93 } 94 // 当前线段树上节点编号为i,将pos位置的值改变为z(pos为dfs2后的新编号) 95 void update1(int i, int pos, int z) { 96 if (pos == tree[i].left && pos == tree[i].right) { 97 tree[i].minn = tree[i].maxx = z; 98 } else { 99 pushdown(i); 100 int l = tree[i].left, r = tree[i].right; 101 int mid = (l + r) >> 1; 102 if (pos <= mid) update1(i << 1, pos, z); 103 else update1(i << 1 | 1, pos, z); 104 pushup(i); 105 } 106 } 107 // 当前线段树上节点编号为i,将[tl, tr]上的值取相反数 108 void update2(int i, int tl, int tr) { 109 if (tree[i].left >= tl && tree[i].right <= tr) { 110 tree[i].tag ^= 1; 111 swap(tree[i].minn, tree[i].maxx); 112 tree[i].minn = -tree[i].minn; 113 tree[i].maxx = -tree[i].maxx; 114 } else { 115 pushdown(i); 116 int l = tree[i].left, r = tree[i].right; 117 int mid = (l + r) >> 1; 118 if (tl <= mid) update2(i << 1, tl, tr); 119 if (tr > mid) update2(i << 1 | 1, tl, tr); 120 pushup(i); 121 } 122 } 123 // 对u、v节点最短路径上的边权取相反数(u, v为原树上节点编号) 124 void Change(int u, int v) { 125 int tu = top[u], tv = top[v]; 126 while (tu != tv) { 127 if (dep[tu] < dep[tv]) swap(u, v), swap(tu, tv); 128 update2(1, id[tu], id[u]); 129 u = fa[tu], tu = top[u]; 130 } 131 if (u == v) return ; 132 if (dep[u] < dep[v]) swap(u, v); 133 update2(1, id[v] + 1, id[u]); 134 } 135 // 当前线段树上节点编号为i,查询[tl, tr]范围的最大值 136 int query(int i, int tl, int tr) { 137 int l = i << 1, r = i << 1 | 1; 138 if (tl <= tree[i].left && tree[i].right <= tr) return tree[i].maxx; 139 if (tree[i].left > tr || tree[i].right < tl) return -INF; 140 pushdown(i); 141 return max(query(l, tl, tr), query(r, tl, tr)); 142 } 143 // 查询原树上节点u、v之间最短路径边权最大值 144 int Query_max(int u, int v) { 145 if (u == v) return 0; 146 int ans = -INF; 147 int tu = top[u], tv = top[v]; 148 while (tu != tv) { 149 if (dep[tu] < dep[tv]) swap(u, v), swap(tu, tv); 150 ans = max(ans, query(1, id[tu], id[u])); 151 u = fa[tu], tu = top[u]; 152 } 153 if (u == v) return ans; 154 if (dep[u] < dep[v]) swap(u, v); 155 return ans = max(ans, query(1, id[v] + 1, id[u])); 156 } 157 158 int main() { 159 int T = read(); 160 while (T--) { 161 init(); 162 n = read(); 163 for (int i = 1; i < n; i++) { 164 xx[i] = read(), yy[i] = read(); 165 int z = read(); 166 add_edge(xx[i], yy[i], z); 167 add_edge(yy[i], xx[i], z); 168 } 169 dfs1(1, -1, 1); 170 dfs2(1, 1); 171 build_tree(1, 1, n); 172 string s; 173 while (cin >> s && s != "DONE") { 174 if (s == "QUERY") { 175 int u = read(), v = read(); 176 printf("%d\n", Query_max(u, v)); 177 } 178 else if (s == "CHANGE") { 179 int ith = read(), newv = read(); 180 if (dep[xx[ith]] < dep[yy[ith]]) swap(xx[ith], yy[ith]); 181 update1(1, id[xx[ith]], newv); 182 } 183 else { 184 int u = read(), v = read(); 185 Change(u, v); 186 } 187 } 188 } 189 return 0; 190 }
部分内容修改自:
https://zhuanlan.zhihu.com/p/30286758?utm_source=wechat_session&utm_medium=social
https://www.cnblogs.com/George1994/p/7821357.html
https://www.cnblogs.com/dalt/p/8206664.html