BZOJ 1036 树链剖分模板题
题意:一棵树,每个点有权值,三种操作:修改一个点的值;询问一条链上最大值;询问一条链上权值和。
tags:模板题
// BZOJ 1036: heavy-light decomposition (HLD) of a tree backed by a segment tree.
//
// Input, as consumed by this program:
//   n, then n-1 undirected edges "u v", then the n node weights, then q,
//   followed by q operations of the form "<word> x y":
//     word starts with 'C'     -> set the weight of node x to y
//     second character is 'M'  -> print the maximum weight on the path x..y
//     anything else            -> print the sum of the weights on the path x..y
#include <algorithm>
#include <cstdio>
#include <cstring>

// Linker directive asking for a larger stack (MSVC syntax); presumably present
// because dfs1/dfs2 recurse once per tree level.
#pragma comment(linker, "/STACK:102400000,102400000")

const int INF = 0x3f3f3f3f;
const int MAXN = 30000;  // largest node count the arrays below are sized for
const int N = MAXN + 10;

struct Edge {
    int to, next;  // target node, index of the next edge leaving the same source (-1 = end)
};

struct Point {
    int mx, sum;  // maximum and sum of the weights in the segment
};

// Every undirected edge is stored as two directed entries, so the pool has to
// hold 2*(n-1) of them; sizing it with N alone overflows for large n.
Edge e[N << 1];
Point t[N << 2];  // segment tree, 1-based, children of ro are ro*2 and ro*2+1

int n;        // number of nodes
int sz;       // number of segment-tree positions handed out so far
int tot;      // number of directed edges stored in e[]
int Size[N];  // subtree size
int dep[N];   // depth (root = 0)
int w[N];     // current node weight
int head[N];  // first outgoing edge of each node, -1 if none
int pos[N];   // position of the node in the segment tree
int bl[N];    // top node of the heavy chain the node belongs to
int fa[N];    // parent (0 for the root)

// Appends the directed edge u -> v to the adjacency list of u.
void Addedge(int u, int v)
{
    e[tot].to = v;
    e[tot].next = head[u];
    head[u] = tot++;
}

// Reads n and the n-1 edges. On malformed or out-of-range input n is reset
// to 0 so that main() can stop before touching the arrays.
void Init()
{
    std::memset(head, -1, sizeof(head));
    if (scanf("%d", &n) != 1 || n < 1 || n > MAXN) {
        n = 0;
        return;
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d %d", &u, &v) != 2) {
            n = 0;
            return;
        }
        Addedge(u, v);
        Addedge(v, u);
    }
}

// First pass: fills fa[], dep[] and Size[] for the subtree rooted at u.
void dfs1(int u)
{
    Size[u] = 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        Size[u] += Size[v];
    }
}

// Second pass: assigns segment-tree positions and chain tops.
// `chain` is the top node of the heavy chain that u continues.
void dfs2(int u, int chain)
{
    pos[u] = ++sz;
    bl[u] = chain;

    // Children are the neighbours with a larger depth. Node 0 is never a real
    // node and Size[0] == 0, so it serves as the "no child yet" sentinel.
    int heavy = 0;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (dep[v] > dep[u] && Size[v] > Size[heavy]) heavy = v;
    }
    if (heavy == 0) return;  // leaf

    // The heavy child is visited first so a chain occupies consecutive positions.
    dfs2(heavy, chain);
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].to;
        if (dep[v] > dep[u] && v != heavy) dfs2(v, v);  // each light child starts a new chain
    }
}

// Segment tree point update: sets position x to value y inside node ro = [l, r].
void update(int ro, int l, int r, int x, int y)
{
    if (l == r) {
        t[ro].mx = t[ro].sum = y;
        return;
    }
    const int mid = (l + r) >> 1;
    if (x <= mid)
        update(ro << 1, l, mid, x, y);
    else
        update(ro << 1 | 1, mid + 1, r, x, y);
    t[ro].mx = std::max(t[ro << 1].mx, t[ro << 1 | 1].mx);
    t[ro].sum = t[ro << 1].sum + t[ro << 1 | 1].sum;
}

// Segment tree range maximum over [x, y]; requires l <= x <= y <= r.
int querymx(int ro, int l, int r, int x, int y)
{
    if (l == x && r == y) return t[ro].mx;
    const int mid = (l + r) >> 1;
    if (y <= mid) return querymx(ro << 1, l, mid, x, y);
    if (mid < x) return querymx(ro << 1 | 1, mid + 1, r, x, y);
    return std::max(querymx(ro << 1, l, mid, x, mid),
                    querymx(ro << 1 | 1, mid + 1, r, mid + 1, y));
}

// Maximum node weight on the tree path between x and y.
int solvemx(int x, int y)
{
    int mx = -INF;
    while (bl[x] != bl[y]) {
        // Always lift the endpoint whose chain top is deeper.
        if (dep[bl[x]] < dep[bl[y]]) std::swap(x, y);
        mx = std::max(mx, querymx(1, 1, sz, pos[bl[x]], pos[x]));
        x = fa[bl[x]];
    }
    // Same chain now: the remaining part is one contiguous position range.
    if (pos[x] > pos[y]) std::swap(x, y);
    mx = std::max(mx, querymx(1, 1, sz, pos[x], pos[y]));
    return mx;
}

// Segment tree range sum over [x, y]; requires l <= x <= y <= r.
int querysum(int ro, int l, int r, int x, int y)
{
    if (l == x && r == y) return t[ro].sum;
    const int mid = (l + r) >> 1;
    if (y <= mid) return querysum(ro << 1, l, mid, x, y);
    if (mid < x) return querysum(ro << 1 | 1, mid + 1, r, x, y);
    return querysum(ro << 1, l, mid, x, mid) +
           querysum(ro << 1 | 1, mid + 1, r, mid + 1, y);
}

// Sum of the node weights on the tree path between x and y.
int solvesum(int x, int y)
{
    int sum = 0;
    while (bl[x] != bl[y]) {
        // Not on the same heavy chain yet: accumulate the deeper chain's
        // segment and jump to the parent of its top.
        if (dep[bl[x]] < dep[bl[y]]) std::swap(x, y);
        sum += querysum(1, 1, sz, pos[bl[x]], pos[x]);
        x = fa[bl[x]];
    }
    if (pos[x] > pos[y]) std::swap(x, y);
    sum += querysum(1, 1, sz, pos[x], pos[y]);
    return sum;
}

// Reads the node weights and the operations, answering each query on stdout.
// Stops quietly as soon as the input runs out or is malformed.
void solve()
{
    for (int i = 1; i <= n; i++) {
        if (scanf("%d", &w[i]) != 1) return;
        update(1, 1, sz, pos[i], w[i]);
    }

    int q = 0;
    if (scanf("%d", &q) != 1) return;

    char str[10];
    int x, y;
    for (int i = 1; i <= q; i++) {
        // %9s keeps the command word inside str[10].
        if (scanf("%9s %d %d", str, &x, &y) != 3) break;
        if (str[0] == 'C') {
            w[x] = y;
            update(1, 1, sz, pos[x], y);
        } else if (str[1] == 'M') {
            printf("%d\n", solvemx(x, y));
        } else {
            printf("%d\n", solvesum(x, y));
        }
    }
}

int main()
{
    Init();
    if (n < 1) return 0;  // nothing usable was read
    dfs1(1);              // root the tree at node 1
    dfs2(1, 1);
    solve();
    return 0;
}