[Bzoj1036][ZJOI2008]树的统计Count(树链剖分)
题目链接:https://www.lydsy.com/JudgeOnline/problem.php?id=1036
树链剖分的板子题,在bzoj上做到就当复习啦
// BZOJ 1036 / ZJOI2008 "Count": heavy-light decomposition over a segment tree.
// Input (repeated until end of input): n, then n-1 undirected edges "x y",
// then the n node weights, then m operations, each "<word> x y":
//   QMAX x y  -> print the maximum weight on the tree path x..y (inclusive)
//   QSUM x y  -> print the sum of weights on the tree path x..y (inclusive)
//   otherwise -> (presumably "CHANGE") set the weight of node x to y
#include <algorithm>
#include <cstdio>
#include <cstring>

// Argument packs for the two halves of segment-tree node i covering [l, r].
#define lson l, mid, i << 1
#define rson mid + 1, r, i << 1 | 1

typedef long long ll;
const int maxn = 200010;
const ll INF = 2000000000LL;  // -INF seeds maximum queries

struct node {
    int s, e, next;  // edge source, edge target, next edge out of s (-1 ends the list)
} edge[maxn * 2];    // each undirected tree edge is stored in both directions

int n, m;
// Heavy-light decomposition arrays (node ids are 1..n):
//   siz[i] - number of nodes in the subtree rooted at i
//   top[i] - topmost node of the heavy chain that contains i
//   son[i] - heavy child of i (child with the largest subtree), -1 for a leaf
//   fat[i] - parent of i (0 for the root, node 1)
//   dep[i] - depth of i; the root has depth 0
//   tid[i] - position of i in the DFS order used by the segment tree (1..n)
//   rak[p] - node at DFS position p, the inverse of tid (rak[tid[j]] == j)
int son[maxn], top[maxn], tid[maxn], fat[maxn], siz[maxn], dep[maxn], rak[maxn];
int head[maxn], len, dfx;  // adjacency-list heads, edges stored, last DFS position used

// Clears the adjacency list and the DFS position counter for a new test case.
void init() {
    std::memset(head, -1, sizeof(head));
    len = 0;
    dfx = 0;
}

// Appends the directed edge s -> e to the adjacency list.
void add(int s, int e) {
    edge[len].s = s;
    edge[len].e = e;
    edge[len].next = head[s];
    head[s] = len++;
}

// First DFS: fills dep, siz, fat and son for the subtree rooted at x.
// fa is the parent of x (0 for the root) and d is the depth of x.
// NOTE: recursion depth equals the tree height; a very deep tree may need a
// larger stack than the platform default.
void dfs1(int x, int fa, int d) {
    dep[x] = d;
    siz[x] = 1;
    fat[x] = fa;
    son[x] = -1;
    for (int i = head[x]; i != -1; i = edge[i].next) {
        int y = edge[i].e;
        if (y == fa)
            continue;
        dfs1(y, x, d + 1);
        siz[x] += siz[y];
        if (son[x] == -1 || siz[y] > siz[son[x]])
            son[x] = y;
    }
}

// Second DFS: assigns DFS positions (tid/rak) and chain tops (top).
// The heavy child is visited first, so each heavy chain gets consecutive
// positions. c is the top of the chain that x belongs to.
void dfs2(int x, int c) {
    top[x] = c;
    tid[x] = ++dfx;
    rak[dfx] = x;
    if (son[x] == -1)
        return;
    dfs2(son[x], c);  // the heavy child continues the current chain
    for (int i = head[x]; i != -1; i = edge[i].next) {
        int y = edge[i].e;
        if (y != fat[x] && y != son[x])
            dfs2(y, y);  // every light child starts a new chain
    }
}

ll a[maxn];        // node weights, indexed by node id
ll sum[maxn * 4];  // segment tree: sum of weights over each node's position range
ll Max[maxn * 4];  // segment tree: maximum weight over each node's position range

// Recomputes segment-tree node i from its two children.
void up(int i) {
    sum[i] = sum[i << 1] + sum[i << 1 | 1];
    Max[i] = std::max(Max[i << 1], Max[i << 1 | 1]);
}

// Builds segment-tree node i over DFS positions [l, r] from the weights in a.
void build(int l, int r, int i) {
    if (l == r) {
        sum[i] = a[rak[l]];
        Max[i] = a[rak[l]];
        return;
    }
    int mid = (l + r) >> 1;
    build(lson);
    build(rson);
    up(i);
}

// Sets the weight stored at DFS position t to k; node i covers [l, r].
void update(int t, int k, int l, int r, int i) {
    if (l == r) {
        sum[i] = k;
        Max[i] = k;
        return;
    }
    int mid = (l + r) >> 1;
    if (t <= mid)
        update(t, k, lson);
    else
        update(t, k, rson);
    up(i);
}

// Maximum weight over positions [L, R] within node i's range [l, r];
// -INF when the ranges do not overlap.
ll queryM(int L, int R, int l, int r, int i) {
    if (L <= l && r <= R)
        return Max[i];
    int mid = (l + r) >> 1;
    ll MAX = -INF;
    if (L <= mid)
        MAX = std::max(MAX, queryM(L, R, lson));
    if (R > mid)
        MAX = std::max(MAX, queryM(L, R, rson));
    return MAX;
}

// Sum of weights over positions [L, R] within node i's range [l, r].
ll queryS(int L, int R, int l, int r, int i) {
    if (L <= l && r <= R)
        return sum[i];
    int mid = (l + r) >> 1;
    ll ans = 0;
    if (L <= mid)
        ans += queryS(L, R, lson);
    if (R > mid)
        ans += queryS(L, R, rson);
    return ans;
}

// Aggregates the weights on the tree path x..y (both ends inclusive):
// the maximum when flg is non-zero, the sum when flg is zero.
ll solve(int x, int y, int flg) {
    ll ans = flg ? -INF : 0;
    // Lift the endpoint whose chain top is deeper, one whole chain at a time,
    // until both endpoints lie on the same heavy chain.
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        if (flg)
            ans = std::max(ans, queryM(tid[top[x]], tid[x], 1, n, 1));
        else
            ans += queryS(tid[top[x]], tid[x], 1, n, 1);
        x = fat[top[x]];
    }
    // Same chain: the shallower node y has the smaller DFS position.
    if (dep[x] < dep[y])
        std::swap(x, y);
    if (flg)
        ans = std::max(ans, queryM(tid[y], tid[x], 1, n, 1));
    else
        ans += queryS(tid[y], tid[x], 1, n, 1);
    return ans;
}

int main() {
    // "== 1" (not "!= EOF") so malformed input stops the loop instead of spinning.
    while (std::scanf("%d", &n) == 1) {
        if (n < 1 || n >= maxn)
            break;  // no tree to build, or it would overrun the fixed-size arrays
        init();
        int x, y;
        for (int i = 0; i < n - 1; i++) {
            std::scanf("%d%d", &x, &y);
            add(x, y);
            add(y, x);
        }
        for (int i = 1; i <= n; i++)
            std::scanf("%lld", &a[i]);
        dfs1(1, 0, 0);
        dfs2(1, 1);
        build(1, n, 1);
        if (std::scanf("%d", &m) != 1)
            break;
        while (m--) {
            char s[10];
            // %9s bounds the word to the buffer (9 chars + terminating NUL).
            std::scanf("%9s%d%d", s, &x, &y);
            if (std::strcmp(s, "QMAX") == 0)
                std::printf("%lld\n", solve(x, y, 1));
            else if (std::strcmp(s, "QSUM") == 0)
                std::printf("%lld\n", solve(x, y, 0));
            else
                update(tid[x], y, 1, n, 1);  // weight of node x becomes y
        }
    }
    return 0;
}