BZOJ 1036 [ZJOI2008]树的统计Count (树链剖分 - 点权剖分 - 单点权修改)
题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=1036
树链剖分模版题,打的时候注意点就行。做这题的时候,真的傻了,单词拼错检查了一个多小时...
代码如下:
1 //树链剖分 点权修改 修改单节点 2 #include <iostream> 3 #include <cstring> 4 #include <algorithm> 5 #include <cstdio> 6 using namespace std; 7 const int MAXN = 3e4 + 10; 8 struct data { 9 int to , next; 10 }edge[MAXN << 1]; 11 int head[MAXN] , cnt , tot; 12 int top[MAXN] , par[MAXN] , son[MAXN] , size[MAXN] , dep[MAXN]; 13 int id[MAXN] , fid[MAXN]; //id[i]表示i对应在线段树上的位置 fid[i]表示线段树位置是i的叶子 对应的节点 14 int a[MAXN]; 15 16 void init() { 17 tot = cnt = 0; 18 memset(head , -1 , sizeof(head)); 19 } 20 21 inline void add(int u , int v) { 22 edge[tot].next = head[u]; 23 edge[tot].to = v; 24 head[u] = tot++; 25 } 26 27 void dfs1(int u , int p , int d) { 28 dep[u] = d , size[u] = 1 , son[u] = u , par[u] = p; 29 for(int i = head[u] ; ~i ; i = edge[i].next) { 30 int v = edge[i].to; 31 if(v == p) 32 continue; 33 dfs1(v , u , d + 1); 34 if(size[v] >= size[son[u]] || son[u] == u) 35 son[u] = v; 36 size[u] += size[v]; 37 } 38 } 39 40 void dfs2(int u , int p , int t) { 41 top[u] = t , id[u] = ++cnt; 42 fid[cnt] = u; 43 if(son[u] != u) 44 dfs2(son[u] , u , t); 45 for(int i = head[u] ; ~i ; i = edge[i].next) { 46 int v = edge[i].to; 47 if(v == p || v == son[u]) 48 continue; 49 dfs2(v , u , v); 50 } 51 } 52 53 struct segtree { 54 int l , r; 55 int sum , Max; 56 }T[MAXN << 2]; 57 58 void build(int p , int l , int r) { 59 int mid = (l + r) >> 1; 60 T[p].l = l , T[p].r = r; 61 if(l == r) { 62 T[p].Max = T[p].sum = a[fid[l]]; // 63 return ; 64 } 65 build(p << 1 , l , mid); 66 build((p << 1)|1 , mid + 1 , r); 67 T[p].sum = T[p << 1].sum + T[(p << 1)|1].sum; 68 T[p].Max = max(T[p << 1].Max , T[(p << 1)|1].Max); 69 } 70 71 void updata(int p , int pos , int num) { 72 int mid = (T[p].l + T[p].r) >> 1; 73 if(T[p].l == T[p].r && T[p].l == pos) { 74 T[p].sum = T[p].Max = num; 75 return ; 76 } 77 if(pos <= mid) { 78 updata(p << 1 , pos , num); 79 } 80 else { 81 updata((p << 1)|1 , pos , num); 82 } 83 T[p].sum = T[p << 1].sum + T[(p << 1)|1].sum; 84 T[p].Max = max(T[p << 1].Max , T[(p << 1)|1].Max); 85 } 86 87 int query_sum(int p , int l , int r) { 88 int mid = (T[p].l + T[p].r) >> 1; 89 if(T[p].l == l && T[p].r == r) { 90 return T[p].sum; 91 } 92 if(r <= mid) { 93 return query_sum(p << 1 , l , r); 94 } 95 else if(l > mid) { 96 return query_sum((p << 1)|1 , l , r); 97 } 98 else { 99 return query_sum(p << 1 , l , mid) + query_sum((p << 1)|1 , mid + 1 , r); 100 } 101 } 102 103 int query_max(int p , int l , int r) { 104 int mid = (T[p].l + T[p].r) >> 1; 105 if(T[p].l == l && T[p].r == r) { 106 return T[p].Max; 107 } 108 if(r <= mid) { 109 return query_max(p << 1 , l , r); 110 } 111 else if(l > mid) { 112 return query_max((p << 1)|1 , l , r); 113 } 114 else { 115 return max(query_max(p << 1 , l , mid) , query_max((p << 1)|1 , mid + 1 , r)); 116 } 117 } 118 119 int find_max(int u , int v) { 120 int fu = top[u] , fv = top[v]; 121 int Max = -1000000000; 122 while(fu != fv) { 123 if(dep[fu] >= dep[fv]) { 124 Max = max(Max , query_max(1 , id[fu] , id[u])); 125 u = par[fu]; 126 fu = top[u]; 127 } 128 else { 129 Max = max(Max , query_max(1 , id[fv] , id[v])); 130 v = par[fv]; 131 fv = top[v]; 132 } 133 } 134 if(dep[u] >= dep[v]) { 135 return max(Max , query_max(1 , id[v] , id[u])); 136 } 137 else { 138 return max(Max , query_max(1 , id[u] , id[v])); 139 } 140 } 141 142 int find_sum(int u , int v) { 143 int fu = top[u] , fv = top[v]; 144 int sum = 0; 145 while(fu != fv) { 146 if(dep[fu] > dep[fv]) { 147 sum += query_sum(1 , id[fu] , id[u]); 148 u = par[fu]; 149 fu = top[u]; 150 } 151 else { 152 sum += query_sum(1 , id[fv] , id[v]); 153 v = par[fv]; 154 fv = top[v]; 155 } 156 } 157 if(dep[u] >= dep[v]) { 158 return (sum + query_sum(1 , id[v] , id[u])); 159 } 160 else { 161 return (sum + query_sum(1 , id[u] , id[v])); 162 } 163 } 164 165 int main() 166 { 167 int n , u , v; 168 while(~scanf("%d" , &n)) { 169 init(); 170 for(int i = 1 ; i < n ; ++i) { 171 scanf("%d %d" , &u , &v); 172 add(u , v); 173 add(v , u); 174 } 175 for(int i = 1 ; i <= n ; ++i) { 176 scanf("%d" , a + i); 177 } 178 cnt = 0; 179 dfs1(1 , 1 , 0); 180 dfs2(1 , 1 , 1); 181 build(1 , 1 , cnt); 182 int m; 183 char q[10]; 184 scanf("%d" , &m); 185 while(m--) { 186 scanf("%s%d%d" , q , &u , &v); 187 if(q[0] == 'C') { 188 updata(1 , id[u] , v); 189 } 190 else if(strcmp(q ,"QMAX") == 0) { 191 printf("%d\n" , find_max(u , v)); 192 } 193 else { 194 printf("%d\n", find_sum(u , v)); 195 } 196 } 197 } 198 return 0; 199 }