BZOJ 1036: [ZJOI2008]树的统计Count(树链剖分+单点更新+区间求和+区间求最大值)
题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=1036
题意:略。
题解:树链剖分模版,注意一些细节即可。
#include <iostream> #include <cstring> #include <cstdio> using namespace std; const int M = 3e4 + 10; struct Edge { int v , next; }edge[M << 1]; int head[M] , e; int top[M]; int fa[M]; int p[M]; int fp[M]; int deep[M]; int num[M]; int son[M]; int pos; void init() { memset(head , -1 , sizeof(head)); memset(son , -1 , sizeof(son)); e = 0; pos = 1; } void add(int u , int v) { edge[e].v = v; edge[e].next = head[u]; head[u] = e++; } void dfs1(int u , int pre , int d) { deep[u] = d; fa[u] = pre; num[u] = 1; for(int i = head[u] ; i != -1 ; i = edge[i].next) { int v = edge[i].v; if(v != pre) { dfs1(v , u , d + 1); num[u] += num[v]; if(son[u] == -1 || num[son[u]] < num[v]) { son[u] = v; } } } } void getpos(int u , int sp) { top[u] = sp; p[u] = pos++; fp[p[u]] = u; if(son[u] == -1) return ; getpos(son[u] , sp); for(int i = head[u] ; i != -1 ; i = edge[i].next) { int v = edge[i].v; if(v != fa[u] && v != son[u]) getpos(v , v); } } struct TnT { int l , r , sum , MAX; }T[M << 2]; int a[M]; void pushup(int i) { T[i].sum = T[i << 1].sum + T[(i << 1) | 1].sum; T[i].MAX = max(T[i << 1].MAX , T[(i << 1) | 1].MAX); } void build(int l , int r , int i) { int mid = (l + r) >> 1; T[i].l = l , T[i].r = r , T[i].MAX = 0 , T[i].sum = 0; if(l == r) { T[i].MAX = a[fp[l]]; T[i].sum = a[fp[l]]; return ; } build(l , mid , i << 1); build(mid + 1 , r , (i << 1) | 1); pushup(i); } void updata(int i , int pos , int ad) { int mid = (T[i].l + T[i].r) >> 1; if(T[i].l == T[i].r && T[i].l == pos) { T[i].MAX = ad; T[i].sum = ad; return ; } if(mid < pos) { updata((i << 1) | 1 , pos , ad); } else { updata(i << 1 , pos , ad); } pushup(i); } int queryM(int l , int r , int i) { int mid = (T[i].l + T[i].r) >> 1; if(T[i].l == l && T[i].r == r) { return T[i].MAX; } pushup(i); if(mid < l) { return queryM(l , r , (i << 1) | 1); } else if(mid >= r) { return queryM(l , r , i << 1); } else { return max(queryM(l , mid , i << 1) , queryM(mid + 1 , r , (i << 1) | 1)); } } int queryS(int l , int r , int i) { int mid = (T[i].l + T[i].r) >> 1; if(T[i].l == l && T[i].r == r) { return T[i].sum; } pushup(i); if(mid < l) { return queryS(l , r , (i << 1) | 1); } else if(mid >= r) { return queryS(l , r , i << 1); } else { return queryS(l , mid , i << 1) + queryS(mid + 1 , r , (i << 1) | 1); } } int findM(int u , int v) { int f1 = top[u] , f2 = top[v]; int tmp = -30010; while(f1 != f2) { if(deep[f1] < deep[f2]) { swap(f1 , f2); swap(u , v); } tmp = max(tmp , queryM(p[f1] , p[u] , 1)); u = fa[f1] , f1 = top[u]; } if(deep[u] > deep[v]) swap(u , v); return max(tmp , queryM(p[u] , p[v] , 1)); } int findS(int u , int v) { int f1 = top[u] , f2 = top[v]; int tmp = 0; while(f1 != f2) { if(deep[f1] < deep[f2]) { swap(f1 , f2); swap(u , v); } tmp += queryS(p[f1] , p[u] , 1); u = fa[f1] , f1 = top[u]; } if(deep[u] > deep[v]) swap(u , v); return tmp + queryS(p[u] , p[v] , 1); } int main() { int n , u , v , m; scanf("%d" , &n); init(); for(int i = 0 ; i < n - 1 ; i++) { scanf("%d%d" , &u , &v); add(u , v); add(v , u); } for(int i = 1 ; i <= n ; i++) { scanf("%d" , &a[i]); } dfs1(1 , 0 , 0); getpos(1 , 1); build(1 , pos , 1); scanf("%d" , &m); char cp[10]; while(m--) { scanf("%s" , cp); scanf("%d%d" , &u , &v); if(cp[0] == 'Q') { if(cp[1] == 'M') { printf("%d\n" , findM(u , v)); } else { printf("%d\n" , findS(u , v)); } } else { updata(1 , p[u] , v); } } return 0; }