【模板】树链剖分
[ZJOI2008]树的统计
第一遍树链剖分,打的很难受。
其中拉闸了,检查真是费劲。
树链剖分是什么?
树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。
树链剖分可以支持链上求和,链上求最值,链上修改等线段树的操作。
但若断开一条边或者连接两个点,保证两个点连接后依然是棵树。这样树链剖分就虚了,因为线段树不支持这种操作,就需要把线段树换成splay,于是LCT = 树剖 + splay。
说明:
重孩子:儿子节点所有孩子中size最大的
轻孩子:儿子节点中除了重儿子的节点
重边:连接重儿子的边
轻边:连接轻儿子的边
重链:重边连成的链
轻链:轻边连成的链
a[i] 表示节点 i 权值
f[i] 表示节点 i 的父亲在原树中的位置
son[i] 表示节点 i 的重儿子在原树中的位置
top[i] 表示节点 i 所在链的顶端节点在原树中的位置,就是深度最小的
size[i] 表示以 i 为根的子树节点个数
tid[i] 表示树中节点 i 剖分后的新编号
rank[i] 表示剖分后的节点 i 在原树中的位置
deep[i] 表示节点 i 深度,根节点深度为 1
实现方法:
第一遍dfs可以预处理出size,deep,f,son数组
第二遍dfs可以预处理出top,tid,rank数组,通过优先搜索重边,然后搜索轻边
树链剖分目的是把树上的边剖分成一个链,就是一个线段,标号是连续的。
为什么要先搜索重边呢?
可以看出,这样搜可以使得重链上的点的dfs序是连续的,可以用线段树来维护。
如何查询呢?
判断两点是否属于同一条重链,如果属于,就直接修改,因为他们是连续的,如果不属于,就从深度大点开始不停地找他父亲跳轻链,其中深度是不停地在变的,也就是说,两个点可能会轮着跳,直到属于同一个重链。现在看来,轻边实际上是连接重链的东西。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include <cstdio> 2 #include <cstring> 3 #include <iostream> 4 #define rt 1, 1, n 5 #define ls o << 1, l, m 6 #define rs o << 1 | 1, m + 1, r 7 8 using namespace std; 9 10 const int maxn = 300001; 11 const int INF = 99999999; 12 int n, m, q, cnt, tim; 13 int a[maxn], head[maxn], to[maxn << 2], next[maxn << 2], deep[maxn], size[maxn]; 14 int son[maxn], top[maxn], f[maxn], tid[maxn], rank[maxn], sumv[maxn], maxv[maxn]; 15 //a节点权值, deep节点深度, size以x为根的子树节点个数, son重儿子, top当前节点所在链的顶端节点 16 //f当前节点父亲, tid保存树中每个节点剖分后的新编号, rank保存剖分后的节点在线段树中的位置 17 18 void add(int x, int y) 19 { 20 to[cnt] = y; 21 next[cnt] = head[x]; 22 head[x] = cnt++; 23 } 24 25 void dfs1(int u, int father)//记录所有重边 26 { 27 int i, v; 28 f[u] = father; 29 size[u] = 1; 30 deep[u] = deep[father] + 1; 31 for(i = head[u]; i != -1; i = next[i]) 32 { 33 v = to[i]; 34 if(v == father) continue; 35 dfs1(v, u); 36 size[u] += size[v]; 37 if(son[u] == -1 || size[v] > size[son[u]]) son[u] = v; 38 } 39 } 40 41 void dfs2(int u, int tp) 42 { 43 int i, v; 44 top[u] = tp; 45 tid[u] = ++tim; 46 rank[tim] = u; 47 if(son[u] == -1) return; 48 dfs2(son[u], tp);//重边 49 for(i = head[u]; i != -1; i = next[i]) 50 { 51 v = to[i]; 52 if(v != son[u] && v != f[u]) dfs2(v, v);//轻边 53 } 54 } 55 56 void pushup(int o) 57 { 58 sumv[o] = sumv[o << 1] + sumv[o << 1 | 1]; 59 maxv[o] = max(maxv[o << 1], maxv[o << 1 | 1]); 60 } 61 62 void updata(int o, int l, int r, int d, int x) 63 { 64 int m = (l + r) >> 1; 65 if(l == r) 66 { 67 sumv[o] = maxv[o] = x; 68 return; 69 } 70 if(d <= m) updata(ls, d, x); 71 else updata(rs, d, x); 72 pushup(o); 73 } 74 75 void build(int o, int l, int r) 76 { 77 int m = (l + r) >> 1; 78 if(l == r) 79 { 80 sumv[o] = maxv[o] = a[rank[l]]; 81 return; 82 } 83 build(ls); 84 build(rs); 85 pushup(o); 86 } 87 88 int querymax(int o, int l, int r, int ql, int qr) 89 { 90 int m = (l + r) >> 1, ans = -INF; 91 if(ql <= l && r <= qr) return maxv[o]; 92 if(ql <= m) ans = max(ans, querymax(ls, ql, qr)); 93 if(m < qr) ans = max(ans, querymax(rs, ql, qr)); 94 pushup(o); 95 return ans; 96 } 97 98 int qmax(int u, int v) 99 { 100 int ans = -INF; 101 while(top[u] != top[v])//判断是否在一条重链上 102 { 103 if(deep[top[u]] < deep[top[v]]) swap(u, v);//深度不同,先处理深度大的 104 ans = max(ans, querymax(rt, tid[top[u]], tid[u])); 105 u = f[top[u]]; 106 } 107 if(deep[u] < deep[v]) swap(u, v);//在同一条重链上了 108 ans = max(ans, querymax(rt, tid[v], tid[u])); 109 return ans; 110 } 111 112 int querysum(int o, int l, int r, int ql, int qr) 113 { 114 int m = (l + r) >> 1, ans = 0; 115 if(ql <= l && r <= qr) return sumv[o]; 116 if(ql <= m) ans += querysum(ls, ql, qr); 117 if(m < qr) ans += querysum(rs, ql, qr); 118 pushup(o); 119 return ans; 120 } 121 122 int qsum(int u, int v) 123 { 124 int ans = 0; 125 while(top[u] != top[v])//判断是否在一条重链上 126 { 127 if(deep[top[u]] < deep[top[v]]) swap(u, v);//深度不同,先处理深度大的 128 ans += querysum(rt, tid[top[u]], tid[u]); 129 u = f[top[u]]; 130 } 131 if(deep[u] < deep[v]) swap(u, v);//在同一条重链上了 132 ans += querysum(rt, tid[v], tid[u]); 133 return ans; 134 } 135 136 int main() 137 { 138 int i, j, x, y; 139 char s[11]; 140 memset(head, -1, sizeof(head)); 141 memset(son, -1, sizeof(son)); 142 scanf("%d", &n); 143 for(i = 1; i < n; i++) 144 { 145 scanf("%d %d", &x, &y); 146 add(x, y); 147 add(y, x); 148 } 149 for(i = 1; i <= n; i++) scanf("%d", &a[i]); 150 dfs1(1, 1);//根节点和他的父亲 151 dfs2(1, 1);//根节点和链头结点 152 build(rt); 153 scanf("%d", &q); 154 for(i = 1; i <= q; i++) 155 { 156 scanf("%s %d %d", s, &x, &y); 157 if(s[1] == 'H') updata(rt, tid[x], y);//把位置为x的点修改为y 158 if(s[1] == 'M') printf("%d\n", qmax(x, y)); 159 if(s[1] == 'S') printf("%d\n", qsum(x, y)); 160 } 161 return 0; 162 }
检查了n遍代码,检查了n遍函数,都没有检查出错误来。
最后偶然发现只是因为调用错了函数。。。
这道题要使一个子树的值都加x,且统计子树所有节点的值的和。
可以发现,一个子树上的点在线段树中的编号是连续的,所以可以对区间 ( tim[x], tim[x] + size[x] - 1 ) 进行操作。
——代码
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include <cstdio> 2 #include <cstring> 3 #include <iostream> 4 #define LL long long 5 #define rt 1, 1, n 6 #define ls o << 1, l, m 7 #define rs o << 1 | 1, m + 1, r 8 9 using namespace std; 10 11 const int maxn = 100001; 12 int n, m, s, cnt, tim; 13 int head[maxn], next[maxn << 2], to[maxn << 2], deep[maxn], f[maxn], size[maxn], son[maxn], top[maxn], rank[maxn], tid[maxn]; 14 LL p, a[maxn], sumv[maxn << 2], addv[maxn << 2]; 15 16 void add(int x, int y) 17 { 18 to[cnt] = y; 19 next[cnt] = head[x]; 20 head[x] = cnt++; 21 } 22 23 void dfs1(int u, int father) 24 { 25 int i, v; 26 size[u] = 1; 27 f[u] = father; 28 deep[u] = deep[father] + 1; 29 for(i = head[u]; i != -1; i = next[i]) 30 { 31 v = to[i]; 32 if(v == father) continue; 33 dfs1(v, u); 34 size[u] += size[v]; 35 if(son[u] == -1 || size[v] > size[son[u]]) son[u] = v; 36 } 37 } 38 39 void dfs2(int u, int tp) 40 { 41 int i, v; 42 top[u] = tp; 43 tid[u] = ++tim; 44 rank[tim] = u; 45 if(son[u] == -1) return; 46 dfs2(son[u], tp); 47 for(i = head[u]; i != -1; i = next[i]) 48 { 49 v = to[i]; 50 if(v != son[u] && v != f[u]) dfs2(v, v); 51 } 52 } 53 54 void pushup(int o) 55 { 56 sumv[o] = (sumv[o << 1] + sumv[o << 1 | 1]) % p; 57 } 58 59 void pushdown(int o, int len) 60 { 61 addv[o << 1] = (addv[o << 1] + addv[o]) % p; 62 addv[o << 1 | 1] = (addv[o << 1 | 1] + addv[o]) % p; 63 sumv[o << 1] = (sumv[o << 1] + addv[o] * (len - (len >> 1))) % p; 64 sumv[o << 1 | 1] = (sumv[o << 1 | 1] + addv[o] * (len >> 1)) % p; 65 addv[o] = 0; 66 } 67 68 void build(int o, int l, int r) 69 { 70 if(l == r) 71 { 72 sumv[o] = a[rank[l]] % p; 73 return; 74 } 75 int m = (l + r) >> 1; 76 build(ls); 77 build(rs); 78 pushup(o); 79 } 80 81 void updata(int o, int l, int r, int ql, int qr, LL d) 82 { 83 if(ql <= l && r <= qr) 84 { 85 addv[o] = (addv[o] + d) % p; 86 sumv[o] = (sumv[o] + d * (r - l + 1)) % p; 87 return; 88 } 89 if(l > qr || r < ql) return; 90 if(addv[o]) pushdown(o, r - l + 1); 91 int m = (l + r) >> 1; 92 updata(ls, ql, qr, d); 93 updata(rs, ql, qr, d); 94 pushup(o); 95 } 96 97 void qdata(int u, int v, LL d) 98 { 99 while(top[u] != top[v]) 100 { 101 if(deep[top[u]] < deep[top[v]]) swap(u, v); 102 updata(rt, tid[top[u]], tid[u], d); 103 u = f[top[u]]; 104 } 105 if(deep[u] > deep[v]) swap(u, v); 106 updata(rt, tid[u], tid[v], d); 107 } 108 109 LL querysum(int o, int l, int r, int ql, int qr) 110 { 111 if(ql <= l && r <= qr) return sumv[o]; 112 if(l > qr || r < ql) return 0; 113 if(addv[o]) pushdown(o, r - l + 1); 114 int m = (l + r) >> 1; 115 return (querysum(ls, ql, qr) + querysum(rs, ql, qr)) % p; 116 } 117 118 LL qsum(int u, int v) 119 { 120 LL ans = 0; 121 while(top[u] != top[v]) 122 { 123 if(deep[top[u]] < deep[top[v]]) swap(u, v); 124 ans = (ans + querysum(rt, tid[top[u]], tid[u])) % p; 125 u = f[top[u]]; 126 } 127 if(deep[u] > deep[v]) swap(u, v); 128 ans = (ans + querysum(rt, tid[u], tid[v])) % p; 129 return ans; 130 } 131 132 int main() 133 { 134 int i, j, c, x, y; 135 LL z; 136 scanf("%d %d %d %lld", &n, &m, &s, &p); 137 for(i = 1; i <= n; i++) scanf("%lld", &a[i]); 138 memset(head, -1, sizeof(head)); 139 memset(son, -1, sizeof(son)); 140 for(i = 1; i < n; i++) 141 { 142 scanf("%d %d", &x, &y); 143 add(x, y); 144 add(y, x); 145 } 146 dfs1(s, s);//根节点和他的父亲 147 dfs2(s, s);//根节点和链头结点 148 build(rt); 149 for(i = 1; i <= m; i++) 150 { 151 scanf("%d", &c); 152 if(c == 1) 153 { 154 scanf("%d %d %lld", &x, &y, &z); 155 qdata(x, y, z % p); 156 } 157 else if(c == 2) 158 { 159 scanf("%d %d", &x, &y); 160 printf("%lld\n", qsum(x, y)); 161 } 162 else if(c == 3) 163 { 164 scanf("%d %lld", &x, &z); 165 updata(rt, tid[x], tid[x] + size[x] - 1, z % p); 166 } 167 else 168 { 169 scanf("%d", &x); 170 printf("%lld\n", querysum(rt, tid[x], tid[x] + size[x] - 1)); 171 } 172 } 173 return 0; 174 }
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步