P3384——树链剖分&&模板
题目描述
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
解决方法
用链式前向星的方式保存树,两次DFS将树剖分成若干重链和轻链,套用线段树进行更新和查询,对子树的操作可以转化成连续节点间的操作(因为DFS时子树节点的编号也是连续的),注意取模和开$long \ \ long$.
而且单独$add$标记时是不用下推的,只需查询时累加即可(不知道为什么那些题解都用下推的)
1 #include<bits/stdc++.h> 2 using namespace std; 3 4 typedef long long ll; 5 #define lc o <<1 6 #define rc o <<1 | 1 7 //const ll INF = 0x3f3f3f3f; 8 const int maxn = 200000 + 10; 9 struct Edge 10 { 11 int to, next; 12 }edges[2*maxn]; 13 int head[maxn]; 14 int cur, f[maxn], deep[maxn], size[maxn], son[maxn], rk[maxn], id[maxn], top[maxn], cnt; 15 int n, q, w[maxn], root, mod; 16 17 inline void addedge(int u, int v) 18 { 19 ++cur; 20 edges[cur].next = head[u]; 21 head[u] = cur; 22 edges[cur].to = v; 23 } 24 25 struct SegTree{ 26 ll sum[maxn << 2], addv[maxn << 2]; 27 void build(int o, int l, int r) 28 { 29 if(l == r) 30 { 31 sum[o] = w[rk[l]] % mod; 32 } 33 else 34 { 35 int mid = (l + r) >> 1; 36 build(lc, l, mid); 37 build(rc, mid+1, r); 38 sum[o] = (sum[lc] + sum[rc]) % mod; 39 } 40 } 41 42 void maintain(int o, int l, int r) 43 { 44 if(l == r) //如果是叶子结点 45 sum[o] = w[rk[l]] % mod; 46 else //如果是非叶子结点 47 sum[o] = (sum[lc] + sum[rc]) % mod; 48 49 sum[o] = (sum[o] + addv[o] * (r-l+1)) % mod; 50 } 51 //区间修改,[cl,cr] += v; 52 void update(int o, int l, int r, int cl, int cr, int v) // 53 { 54 if(cl <= l && r <= cr) addv[o] = (addv[o] + v) % mod; 55 else 56 { 57 int m = l + (r-l) /2; 58 if(cl <= m) update(lc, l, m, cl, cr, v); 59 if(cr > m) update(rc, m+1, r, cl, cr, v); 60 } 61 maintain(o, l, r); 62 } 63 64 //区间查询,sum{ql,qr} 65 ll query(int o, int l,int r, ll add, int ql, int qr) 66 { 67 if(ql <= l && r <= qr) 68 { 69 //prllf("sum[o]:%d %d*(%d-%d+1)\n", sum[o], add, r, l); 70 return (sum[o] + add * (r-l+1)) % mod; //tx l-r+1 71 } 72 else 73 { 74 int m = l + (r - l) / 2; 75 ll ans = 0; 76 add = (add + addv[o]) % mod; 77 if(ql <= m) ans = (ans + query(lc, l, m, add, ql, qr)) % mod; 78 if(qr > m) ans = (ans + query(rc, m+1, r, add, ql, qr)) % mod; 79 return ans; 80 } 81 } 82 }st; 83 84 void dfs1(int u, int fa, int depth) //当前节点、父节点、层次深度 85 { 86 //prllf("u:%d fa:%d depth:%d\n", u, fa, depth); 87 f[u] = fa; 88 deep[u] = depth; 89 size[u] = 1; //这个点本身的size 90 for(int i = head[u];i;i = edges[i].next) 91 { 92 int v = edges[i].to; 93 if(v == fa) continue; 94 dfs1(v, u, depth+1); 95 size[u] += size[v]; //子节点的size已被处理,用它来更新父节点的size 96 if(size[v] > size[son[u]]) son[u] = v; //选取size最大的作为重儿子 97 } 98 } 99 100 void dfs2(int u, int t) //当前节点、重链顶端 101 { 102 //prllf("u:%d t:%d\n", u, t); 103 top[u] = t; 104 id[u] = ++cnt; //标记dfs序 105 rk[cnt] = u; //序号cnt对应节点u 106 if(!son[u]) return; //没有儿子? 107 dfs2(son[u], t); //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续 108 109 for(int i = head[u];i;i = edges[i].next) 110 { 111 int v = edges[i].to; 112 if(v != son[u] && v != f[u]) dfs2(v, v); //这个点位于轻链顶端,那么它的top必然为它本身 113 } 114 } 115 116 117 118 /*修改和查询的原理是一致的,以查询操作为例,其实就是个LCA,不过这里要使用top数组加速,因为top可以直接跳到该重链的起始顶点*/ 119 /*注意,每次循环只能跳一次,并且让结点深的那个跳到top的位置,避免两者一起跳而插肩而过*/ 120 ll querysum(int x, int y) 121 { 122 int fx = top[x], fy = top[y]; 123 ll ans = 0; 124 while(fx != fy) //当两者不在同一条重链上 125 { 126 if(deep[fx] >= deep[fy]) 127 { 128 //prllf("%d %d\n", id[fx], id[x]); 129 ans = (ans + st.query(1, 1, n, 0, id[fx], id[x])) % mod; //线段树区间求和,计算这条重链的贡献 130 x = f[fx]; fx = top[x]; 131 } 132 else 133 { 134 //prllf("%d %d\n", id[fy], id[y]); 135 ans = (ans + st.query(1, 1, n, 0, id[fy], id[y])) % mod; 136 y = f[fy]; fy = top[y]; 137 } 138 } 139 140 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 141 if(id[x] <= id[y]) 142 { 143 //prllf("%d %d\n", id[x], id[y]); 144 ans = (ans + st.query(1, 1, n, 0, id[x], id[y])) % mod; 145 } 146 else 147 { 148 //prllf("%d %d\n", id[y], id[x]); 149 ans = (ans + st.query(1, 1, n, 0, id[y], id[x])) % mod; 150 } 151 return ans; 152 } 153 154 void update_add(int x, int y, int add) 155 { 156 int fx = top[x], fy = top[y]; 157 while(fx != fy) //当两者不在同一条重链上 158 { 159 if(deep[fx] >= deep[fy]) 160 { 161 st.update(1, 1, n, id[fx], id[x], add); 162 x = f[fx]; fx = top[x]; 163 } 164 else 165 { 166 st.update(1, 1, n, id[fy], id[y], add); 167 y = f[fy]; fy = top[y]; 168 } 169 } 170 //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献 171 if(id[x] <= id[y]) st.update(1, 1, n, id[x], id[y], add); 172 else st.update(1, 1, n, id[y], id[x], add); 173 } 174 175 176 int main() 177 { 178 scanf("%d%d%d%d", &n, &q, &root, &mod); 179 for(int i = 1;i <= n;i++) 180 { 181 scanf("%d", &w[i]); 182 w[i] %= mod; 183 } 184 for(int i = 1;i < n;i++) 185 { 186 int u, v; 187 scanf("%d%d", &u, &v); 188 addedge(u, v); 189 addedge(v, u); 190 } 191 dfs1(root, -1, 1); 192 dfs2(root, root); 193 194 // for(ll i = 1;i <= n;i++) prllf("%d ", id[i]); 195 // prllf("\n"); 196 // for(ll i = 1;i <= n;i++) prllf("%d ", rk[i]); 197 // prllf("\n"); 198 199 st.build(1, 1, n); 200 //scanf("%d", &q); 201 while(q--) 202 { 203 int op; 204 scanf("%d", &op); 205 if(op == 1) 206 { 207 int u, v, add; 208 scanf("%d%d%d", &u, &v, &add); 209 update_add(u, v, add); 210 } 211 else if(op == 2) 212 { 213 int u, v; 214 scanf("%d%d", &u, &v); 215 printf("%lld\n", querysum(u, v)); 216 } 217 else if(op == 3) 218 { 219 int u, add; 220 scanf("%d%d", &u, &add); 221 st.update(1, 1, n, id[u], id[u]+size[u]-1, add); 222 } 223 else 224 { 225 int u; 226 scanf("%d", &u); 227 printf("%lld\n",st.query(1, 1, n, 0, id[u], id[u]+size[u]-1)); 228 } 229 //st.prll_debug(1, 1, n); 230 } 231 }
个性签名:时间会解决一切