BZOJ 1036 树的统计(树链剖分)
PS：这是一道很基础的树链剖分题。
// Heavy-light decomposition + segment tree for path queries on a weighted tree.
//
// Input (stdin):
//   n, then n-1 undirected edges "x y" (nodes numbered 1..n),
//   then n initial weights, then q, then q commands "<op> x y".
// Commands are recognised by prefix:
//   'C...'  (e.g. CHANGE x y) -- set the weight of node x to y
//   'QM...' (e.g. QMAX x y)   -- print the maximum weight on the path x..y
//   'QS...' (e.g. QSUM x y)   -- print the sum of weights on the path x..y
#include <algorithm>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

const int N = 30010;

// Segment-tree node: sum and maximum of the weights in its interval.
struct node {
    int sum, mx;
} tree[N << 2];

// f[x]      : position of node x in the decomposition order (1-based)
// fp[p]     : node stored at position p (inverse of f)
// son[x]    : heavy child of x, 0 if x has no children
// deep[x]   : depth of x (the root, node 1, has depth 0)
// father[x] : parent of x (0 for the root)
// sz[x]     : number of nodes in the subtree of x
// a[x]      : initial weight of x
// top[x]    : shallowest node of the heavy chain containing x
int f[N], fp[N], son[N], deep[N], father[N], sz[N], a[N], top[N];
int tot, n;
std::vector<int> v[N];  // adjacency lists

// First DFS: fill depth, parent, subtree size and heavy child of every node
// in the subtree of x.
void dfs1(int x, int fa, int dep) {
    deep[x] = dep;
    father[x] = fa;
    son[x] = 0;
    sz[x] = 1;
    for (int u : v[x]) {
        if (u == fa) continue;
        dfs1(u, x, dep + 1);
        sz[x] += sz[u];
        // sz[0] is never written (node ids start at 1), so it stays 0 and the
        // first child always replaces the "no heavy child" sentinel.
        if (sz[son[x]] < sz[u]) son[x] = u;
    }
}

// Second DFS: assign positions so that every heavy chain occupies a
// contiguous range. The heavy child is visited first so it continues the
// chain headed by tp; each light child starts a new chain.
void dfs2(int x, int tp) {
    top[x] = tp;
    f[x] = ++tot;
    fp[f[x]] = x;
    if (son[x]) dfs2(son[x], tp);
    for (int u : v[x]) {
        if (u == father[x] || u == son[x]) continue;
        dfs2(u, u);
    }
}

// Recompute node i's aggregates from its two children.
inline void pushup(int i) {
    tree[i].sum = tree[i << 1].sum + tree[i << 1 | 1].sum;
    tree[i].mx = std::max(tree[i << 1].mx, tree[i << 1 | 1].mx);
}

// Build node i covering positions [L, R] from the initial weights.
void build(int i, int L, int R) {
    if (L == R) {
        tree[i].sum = tree[i].mx = a[fp[L]];
        return;
    }
    int mid = (L + R) >> 1;
    build(i << 1, L, mid);
    build(i << 1 | 1, mid + 1, R);
    pushup(i);
}

// Point assignment: set position pos (which must lie in [L, R]) to val.
void update(int i, int L, int R, int pos, int val) {
    if (L == R) {
        tree[i].sum = tree[i].mx = val;
        return;
    }
    int mid = (L + R) >> 1;
    if (pos <= mid)
        update(i << 1, L, mid, pos, val);
    else
        update(i << 1 | 1, mid + 1, R, pos, val);
    pushup(i);
}

// Maximum over positions [l, r], where [l, r] lies inside node i's [L, R].
int query_max(int i, int L, int R, int l, int r) {
    if (L == l && R == r) return tree[i].mx;
    int mid = (L + R) >> 1;
    if (r <= mid) return query_max(i << 1, L, mid, l, r);
    if (l > mid) return query_max(i << 1 | 1, mid + 1, R, l, r);
    return std::max(query_max(i << 1, L, mid, l, mid),
                    query_max(i << 1 | 1, mid + 1, R, mid + 1, r));
}

// Sum over positions [l, r], where [l, r] lies inside node i's [L, R].
int query_sum(int i, int L, int R, int l, int r) {
    if (L == l && R == r) return tree[i].sum;
    int mid = (L + R) >> 1;
    if (r <= mid) return query_sum(i << 1, L, mid, l, r);
    if (l > mid) return query_sum(i << 1 | 1, mid + 1, R, l, r);
    return query_sum(i << 1, L, mid, l, mid) +
           query_sum(i << 1 | 1, mid + 1, R, mid + 1, r);
}

// Maximum weight on the tree path x..y (both endpoints included).
int find_max(int x, int y) {
    // True identity for max, so any weight range is handled correctly.
    int ret = std::numeric_limits<int>::min();
    // Repeatedly lift the endpoint whose chain head is deeper until both
    // endpoints share a chain.
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        ret = std::max(ret, query_max(1, 1, n, f[top[x]], f[x]));
        x = father[top[x]];
    }
    // Same chain: the shallower node has the smaller position.
    if (deep[x] > deep[y]) std::swap(x, y);
    return std::max(ret, query_max(1, 1, n, f[x], f[y]));
}

// Sum of weights on the tree path x..y (both endpoints included).
int find_sum(int x, int y) {
    int ret = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        ret += query_sum(1, 1, n, f[top[x]], f[x]);
        x = father[top[x]];
    }
    if (deep[x] > deep[y]) std::swap(x, y);
    return ret + query_sum(1, 1, n, f[x], f[y]);
}

int main() {
    // Without a valid n >= 1, build(1, 1, n) would recurse forever.
    if (std::scanf("%d", &n) != 1 || n < 1) return 0;
    for (int i = 1; i < n; ++i) {
        int x, y;
        if (std::scanf("%d%d", &x, &y) != 2) return 0;
        v[x].push_back(y);
        v[y].push_back(x);
    }
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &a[i]) != 1) return 0;
    }

    dfs1(1, 0, 0);
    dfs2(1, 1);
    build(1, 1, n);

    int q;
    if (std::scanf("%d", &q) != 1) return 0;
    char s[10];
    while (q--) {
        int x, y;
        // Field width keeps the command token inside s (9 chars + NUL).
        if (std::scanf("%9s%d%d", s, &x, &y) != 3) break;
        if (s[0] == 'C')
            update(1, 1, n, f[x], y);
        else if (s[0] == 'Q' && s[1] == 'M')
            std::printf("%d\n", find_max(x, y));
        else if (s[0] == 'Q' && s[1] == 'S')
            std::printf("%d\n", find_sum(x, y));
    }
    return 0;
}