比较基础的一道树链剖分的题 大概还是得说说思路
树链剖分是将树剖成很多条链,比较常见的剖法是按儿子的size来剖分,剖分完后对于这课树的询问用线段树维护——比如求路径和的话——随着他们各自的链向上走,直至他们在同一条链上为止。比较像lca的方法,只不过这里是按链为单位,而且隔壁的SymenYang说可以用树链剖分做lca。。吓哭
然后说说惨痛的调题经历:边表一定要开够啊! 不是n-1 而是2*(n-1)啊! 然后写变量时原始值和映射值要搞清楚啊! 不要搞错了! 还有就是下次求最小值一定看清下界是多少! 树的统计是-30000 ~ 30000 ,我果断naive 的写了一个初值为0!!! wa 0 就是这么痛苦! 还是too Young too Simple !
code :
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 7 const int maxn = 50001; 8 9 struct edge{ 10 int t; edge* next; 11 }e[maxn*3], *head[maxn];int ne = 0; 12 13 void addedge(int f, int t){ 14 e[ne].t = t; e[ne].next = head[f]; head[f] = e + ne ++; 15 } 16 17 int n; 18 int size[maxn],fa[maxn],dep[maxn],w[maxn],un[maxn],map[maxn]; 19 20 struct node{ 21 int smax, sum; 22 node *l, *r; 23 }tr[maxn * 3], *root; int trne = 0; 24 25 node* build(int l, int r){ 26 node* x = tr + trne ++; 27 if(l != r) { 28 int mid = (l + r) >> 1; 29 x-> l = build(l, mid); 30 x-> r = build(mid + 1, r); 31 } 32 return x; 33 } 34 35 void update(node* x){ 36 if(x-> l) { 37 x-> sum = x-> l-> sum + x-> r->sum; 38 x-> smax = max(x-> l-> smax, x-> r-> smax); 39 } 40 } 41 42 void insert(node* x, int l, int r, int pos, int v) { 43 if(l == r) { x-> sum = v, x-> smax = v;} 44 else{ 45 int mid = (l + r) >> 1; 46 if(pos <= mid) insert(x-> l, l, mid, pos, v); 47 else insert(x-> r, mid + 1, r, pos, v); 48 update(x); 49 } 50 } 51 52 int ask(node* x, int l, int r, int ls, int rs, int flag) { 53 if(l == ls && r == rs) { 54 if(flag == 0) return x-> smax; 55 else return x-> sum; 56 } 57 else { 58 int mid = (l + r) >> 1; 59 if(rs <= mid) return ask(x-> l, l, mid, ls, rs, flag); 60 else if(ls >= mid + 1) return ask(x-> r, mid + 1, r, ls, rs, flag); 61 else { 62 if(flag == 0) 63 return max(ask(x->l, l, mid, ls, mid, flag), ask(x-> r, mid + 1, r, mid + 1, rs, flag)); 64 else 65 return ask(x-> l, l, mid, ls, mid, flag) + ask(x-> r, mid + 1, r, mid + 1, rs, flag); 66 } 67 } 68 } 69 70 int cnt = 0; 71 72 void size_cal(int x, int pre) { 73 if(pre == -1) dep[x] = 1, fa[x] = x; 74 else dep[x] = dep[pre] + 1, fa[x] = pre; 75 76 size[x] = 1; 77 for(edge* p = head[x]; p; p = p-> next) 78 if(dep[p-> t] == -1)size_cal(p-> t, x), size[x] += size[p-> t]; 79 } 80 81 void divide(int x, int pre){ 82 if(pre == -1) un[x] = x; 83 else un[x] = un[pre]; 84 map[x] = ++ cnt; insert(root, 1, n, map[x], w[x]); 85 int tmax = -1, ts = -1; 86 for(edge* p = head[x]; p; p = p-> next) { 87 if(dep[p-> t] > dep[x] && size[p-> t] > tmax) tmax = size[p-> t], ts = p-> t; 88 } 89 if(ts != -1) { 90 divide(ts, x); 91 for(edge* p = head[x]; p; p = p-> next) { 92 if(dep[p-> t] > dep[x] && p-> t != ts) divide(p-> t, -1); 93 } 94 } 95 } 96 97 void read() { 98 memset(dep,-1,sizeof(dep)); 99 scanf("%d", &n); 100 root = build(1, n); 101 for(int i = 1; i <= n - 1; i++) { 102 int f, t; 103 scanf("%d%d", &f, &t); 104 addedge(f, t), addedge(t, f); 105 } 106 for(int i = 1; i <= n; ++ i) { 107 scanf("%d", &w[i]); 108 } 109 size_cal(1, -1);divide(1, -1); 110 } 111 112 int sovmax(int a, int b) { 113 int ans = -30001; int ls, rs; 114 while(un[a] != un[b]) { 115 if(dep[un[a]] > dep[un[b]]) { 116 ls = map[a]; rs = map[un[a]]; 117 if(ls > rs) swap(ls, rs); 118 ans = max(ans, ask(root, 1, n, ls, rs, 0)); 119 a = fa[un[a]]; 120 } 121 else { 122 ls = map[b]; rs = map[un[b]]; 123 if(ls > rs) swap(ls, rs); 124 ans = max(ans, ask(root, 1, n, ls, rs, 0)); 125 b = fa[un[b]]; 126 } 127 } 128 ls = map[a], rs = map[b]; 129 if(ls > rs) swap(ls,rs); 130 ans = max(ans, ask(root, 1, n, ls, rs, 0)); 131 return ans; 132 } 133 134 int sovsum(int a,int b) { 135 int ans = 0; int ls, rs; 136 while(un[a] != un[b]) { 137 if(dep[un[a]] > dep[un[b]]) { 138 ls = map[a], rs = map[un[a]]; 139 if(ls > rs) swap(ls, rs); 140 ans += ask(root, 1, n, ls, rs, 1); 141 a = fa[un[a]]; 142 } 143 else { 144 ls = map[b]; rs = map[un[b]]; 145 if(ls > rs) swap(ls, rs); 146 ans += ask(root, 1, n, ls, rs, 1); 147 b = fa[un[b]]; 148 } 149 } 150 ls = map[a], rs = map[b]; 151 if(ls > rs) swap(ls, rs); 152 ans += ask(root, 1, n, ls, rs, 1); 153 return ans; 154 } 155 156 void sov() { 157 int m; 158 scanf("%d", &m); 159 while(m --) { 160 char s[10]; int ls, rs; 161 scanf("%s %d%d", s + 1, &ls, &rs); 162 if(s[2] == 'M') printf("%d\n", sovmax(ls, rs)); 163 if(s[2] == 'S') printf("%d\n", sovsum(ls, rs)); 164 if(s[2] == 'H') insert(root, 1, n, map[ls], rs); 165 } 166 } 167 168 int main(){ 169 read();sov(); 170 return 0; 171 }