[BZOJ1036][ZJOI2008]树的统计Count(树链剖分)
树链剖分模板题:用线段树维护重链上的点权,支持单点修改、查询路径点权和与路径点权最大值。
Code
// Tree path queries via heavy-light decomposition + segment tree.
// Input (as read by Init/solve): n, then n-1 edges, then n node weights,
// then m, then m commands of the form "<word> <int> <int>".
// Supported operations: point assignment, path sum, path maximum.
#include <cstdio>
#include <algorithm>

// Declares, for the segment-tree node `id` covering [l,r], the split point
// `mid` and the indices `ls` / `rs` of its left and right children.
#define MID int mid=(l+r)>>1,ls=id<<1,rs=id<<1|1
// Capacity of every per-node array (nodes are numbered from 1).
#define N 30010
using namespace std;

// Identity element for max(); same value the original obtained from -1e9.
const int NEG_INF=-1000000000;

// Segment-tree node: sum and maximum of the weights in its range.
struct node{
    int sum,mx;
    node(){sum=0,mx=NEG_INF;}
}T[N*4];

// Adjacency-list edge: target vertex and index of the next edge of the same source.
struct info{int to,nex;}e[N*2];

int n,A[N],tot,head[N];
int tag[N*4];                    // not referenced anywhere in this file; kept so the set of globals is unchanged
int dep[N],fa[N],sz[N],son[N];   // depth, parent, subtree size, heavy child (0 = none)
int cnt,tp[N],tw[N],tid[N];      // position counter, chain top, weight by position, position of node

// Reads one (optionally negative) decimal integer from stdin.
// Skips everything up to the first digit; a '-' seen while skipping negates the result.
// Returns 0 if end of input is reached before any digit.
inline int read(){
    int x=0,f=1;
    int ch=getchar();            // int, not char, so EOF is distinguishable from data
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x*f;
}

// Adds the directed edge u -> v at the front of u's adjacency list.
inline void Link(int u,int v){
    e[++tot].nex=head[u];
    e[tot].to=v;
    head[u]=tot;
}

// Point assignment: sets position p to x in the tree rooted at `id` covering [l,r],
// then recomputes sum / max on the way back up.
void update(int l,int r,int id,int p,int x){
    if(l==r){
        T[id].sum=T[id].mx=x;
        return;
    }
    MID;
    if(p<=mid) update(l,mid,ls,p,x);
    else update(mid+1,r,rs,p,x);
    T[id].sum=T[ls].sum+T[rs].sum;
    T[id].mx=max(T[ls].mx,T[rs].mx);
}

// Returns the sum of positions [L,R] within the node `id` covering [l,r].
int querySum(int l,int r,int id,int L,int R){
    if(L<=l&&r<=R) return T[id].sum;
    MID;
    int res=0;
    if(L<=mid) res+=querySum(l,mid,ls,L,R);
    if(R>mid) res+=querySum(mid+1,r,rs,L,R);
    return res;
}

// Returns the maximum of positions [L,R] within the node `id` covering [l,r].
int queryMx(int l,int r,int id,int L,int R){
    if(L<=l&&r<=R) return T[id].mx;
    MID;
    int res=NEG_INF;
    if(L<=mid) res=max(res,queryMx(l,mid,ls,L,R));
    if(R>mid) res=max(res,queryMx(mid+1,r,rs,L,R));
    return res;
}

// First pass: fills dep, fa, sz and picks son[u] as the child with the
// largest subtree (first such child on ties, because the comparison is strict).
void dfs(int u,int pre){
    sz[u]=1;
    for(int i=head[u],mx=0;i;i=e[i].nex){
        int v=e[i].to;
        if(v==pre) continue;
        dep[v]=dep[u]+1;
        fa[v]=u;
        dfs(v,u);
        sz[u]+=sz[v];
        if(sz[v]>mx) son[u]=v,mx=sz[v];
    }
}

// Second pass: assigns positions. The heavy child is visited first, so every
// chain occupies consecutive positions; each light child starts its own chain.
void dddfs(int u,int top){
    tp[u]=top;
    tid[u]=++cnt;
    tw[cnt]=A[u];
    if(!son[u]) return;          // leaf: nothing below
    dddfs(son[u],top);
    for(int i=head[u];i;i=e[i].nex){
        int v=e[i].to;
        if(v!=son[u]&&v!=fa[u]) dddfs(v,v);
    }
}

// Sum of weights on the tree path between u and v (both endpoints included).
inline int qRangeSum(int u,int v){
    int res=0;
    while(tp[u]!=tp[v]){
        // Always lift the endpoint whose chain top is deeper.
        if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
        res+=querySum(1,n,1,tid[tp[u]],tid[u]);
        u=fa[tp[u]];
    }
    // Same chain now: the shallower node has the smaller position.
    if(dep[u]>dep[v]) swap(u,v);
    res+=querySum(1,n,1,tid[u],tid[v]);
    return res;
}

// Maximum weight on the tree path between u and v (both endpoints included).
inline int qRangeMx(int u,int v){
    int res=NEG_INF;
    while(tp[u]!=tp[v]){
        if(dep[tp[u]]<dep[tp[v]]) swap(u,v);
        res=max(res,queryMx(1,n,1,tid[tp[u]],tid[u]));
        u=fa[tp[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    res=max(res,queryMx(1,n,1,tid[u],tid[v]));
    return res;
}

// Builds the segment tree over tw[l..r].
void build(int l,int r,int id){
    if(l==r){
        T[id].sum=T[id].mx=tw[l];
        return;
    }
    MID;
    build(l,mid,ls);
    build(mid+1,r,rs);
    T[id].sum=T[ls].sum+T[rs].sum;
    T[id].mx=max(T[ls].mx,T[rs].mx);
}

// Reads the tree and the weights, then runs both decomposition passes from
// node 1 and builds the segment tree. If n is missing or does not fit the
// fixed-size arrays, n is set to 0 and nothing is built.
inline void Init(){
    n=read();
    if(n<1||n>=N){               // build(1,0,1) would never terminate; n>=N would overrun the arrays
        n=0;
        return;
    }
    for(int i=1;i<n;++i){
        int u=read(),v=read();
        Link(u,v),Link(v,u);
    }
    for(int i=1;i<=n;++i) A[i]=read();
    dfs(1,0);
    dddfs(1,1);
    build(1,n,1);
}

// Processes the m commands. Dispatch looks only at the second letter of the
// command word: 'H' -> assign weight y to node x, 'M' -> print path maximum,
// anything else -> print path sum (presumably CHANGE / QMAX / QSUM - confirm
// against the problem statement).
inline void solve(){
    if(n<1) return;              // Init rejected the input; there is no tree to query
    int m=read(),x,y;
    char s[10];
    while(m-->0){
        // %9s bounds the write into s; stop on truncated or malformed input
        // instead of acting on stale values.
        if(scanf("%9s%d%d",s,&x,&y)!=3) break;
        if(s[1]=='H') update(1,n,1,tid[x],y);
        else if(s[1]=='M') printf("%d\n",qRangeMx(x,y));
        else printf("%d\n",qRangeSum(x,y));
    }
}

int main(){
    Init();
    solve();
    return 0;
}