树的统计 —— 树链剖分（Heavy-Light Decomposition）+ 线段树解法
Code:
// Tree statistics via heavy-light decomposition + a segment tree.
//
// Input format (as read by main below):
//   n
//   n-1 pairs "a b"   : undirected tree edges, nodes numbered 1..n
//   n integers        : weight of node 1..n
//   T
//   T commands, each one of
//     CHANGE u t      : set the weight of node u to t
//     QMAX u v        : print the maximum weight on the path u..v (inclusive)
//     QSUM u v        : print the sum of the weights on the path u..v (inclusive)
//
// Keep every preprocessor directive and every `//` comment on its own line:
// a `//` comment runs to the end of the line and would swallow any code
// written after it.
#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>

typedef long long ll;

const int maxn = 30000 + 5;  // node capacity: nodes are 1..n with n < maxn

// Adjacency list; each undirected edge is stored once per direction.
// Edge ids start at 1 so that 0 terminates a list.
int head[maxn], nex[maxn * 2], to[maxn * 2];
int cnt;

// Heavy-light decomposition data.
int dep[maxn];  // depth of each node, the root has depth 1
int siz[maxn];  // subtree size
int son[maxn];  // heavy child, 0 if none (siz[0] is never written, stays 0)
int p[maxn];    // parent, 0 for the root
int top[maxn];  // topmost node of the heavy chain containing the node
int A[maxn];    // position of the node in decomposition order, 1..n
int cnt2;       // last position handed out by dfs2

int val[maxn];   // val[u]: initial weight of node u
int val2[maxn];  // val2[A[u]] == val[u]: weights laid out in decomposition order

// Segment tree over positions 1..n; node o has children 2o and 2o+1.
int maxv[maxn * 4 + 3];
ll sumv[maxn * 4 + 3];

int n;

// Appends the directed edge u -> v to u's adjacency list.
void addedge(int u, int v) {
    nex[++cnt] = head[u];
    head[u] = cnt;
    to[cnt] = v;
}

// First pass: parent, depth, subtree size and heavy child of every node in
// the subtree of u (fa is u's parent, d is u's depth).
void dfs1(int u, int fa, int d) {
    p[u] = fa;
    dep[u] = d;
    siz[u] = 1;
    for (int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if (v == fa) continue;
        dfs1(v, u, d + 1);
        siz[u] += siz[v];
        // son[u] == 0 means "none yet"; siz[0] == 0, so the first child always
        // wins, and later children replace it only if strictly larger.
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

// Second pass: assigns chain tops and positions. The heavy child is visited
// first so that every heavy chain occupies a contiguous range of positions.
void dfs2(int u, int tp) {
    top[u] = tp;
    A[u] = ++cnt2;
    if (son[u]) dfs2(son[u], tp);
    for (int i = head[u]; i; i = nex[i]) {
        int v = to[i];
        if (v != p[u] && v != son[u]) dfs2(v, v);  // each light child starts a new chain
    }
}

// Builds the segment tree node o covering positions [L, R] from arr[L..R].
void build_tree(int L, int R, int o, int arr[]) {
    if (L == R) {
        sumv[o] = maxv[o] = arr[L];
        return;
    }
    int mid = (L + R) / 2;
    build_tree(L, mid, o * 2, arr);
    build_tree(mid + 1, R, o * 2 + 1, arr);
    sumv[o] = sumv[o * 2] + sumv[o * 2 + 1];
    maxv[o] = std::max(maxv[o * 2], maxv[o * 2 + 1]);
}

// Sets position l to value k inside node o covering [L, R].
void update(int l, int k, int L, int R, int o) {
    if (L == R) {
        sumv[o] = maxv[o] = k;
        return;
    }
    int mid = (L + R) / 2;
    if (l <= mid) update(l, k, L, mid, o * 2);
    else update(l, k, mid + 1, R, o * 2 + 1);
    sumv[o] = sumv[o * 2] + sumv[o * 2 + 1];
    maxv[o] = std::max(maxv[o * 2], maxv[o * 2 + 1]);
}

// Adds the sum of positions [l, r] to `sum` and folds their maximum into `mx`,
// searching node o which covers [L, R].
void query(int l, int r, int L, int R, int o, ll &sum, int &mx) {
    if (l <= L && R <= r) {
        sum += sumv[o];
        mx = std::max(mx, maxv[o]);
        return;
    }
    int mid = (L + R) / 2;
    if (l <= mid) query(l, r, L, mid, o * 2, sum, mx);
    if (r > mid) query(l, r, mid + 1, R, o * 2 + 1, sum, mx);
}

// Computes the sum and the maximum of the node weights on the path x..y
// (both endpoints included). Replaces the old global `_max` side channel.
void path_query(int x, int y, ll &sum, int &mx) {
    sum = 0;
    mx = INT_MIN;  // the path is never empty, so this is always overwritten
    while (top[x] != top[y]) {
        // Climb from whichever endpoint's chain top is deeper; that chain
        // segment top[x]..x lies entirely on the path.
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        query(A[top[x]], A[x], 1, n, 1, sum, mx);
        x = p[top[x]];
    }
    // Same chain now: the shallower node has the smaller position.
    if (dep[x] > dep[y]) std::swap(x, y);
    query(A[x], A[y], 1, n, 1, sum, mx);
}

int main() {
    // n must be at least 1 (build_tree(1, 0, ...) would never terminate) and
    // must fit the fixed-size arrays.
    if (scanf("%d", &n) != 1 || n < 1 || n >= maxn) return 0;
    for (int i = 1; i < n; ++i) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return 0;
        addedge(a, b);
        addedge(b, a);
    }
    for (int i = 1; i <= n; ++i) {
        if (scanf("%d", &val[i]) != 1) return 0;
    }

    dfs1(1, 0, 1);  // root the tree at node 1
    dfs2(1, 1);
    for (int i = 1; i <= n; ++i) val2[A[i]] = val[i];
    build_tree(1, n, 1, val2);  // build the segment tree

    int T;
    if (scanf("%d", &T) != 1) return 0;
    char S[20];
    while (T--) {
        if (scanf("%19s", S) != 1) break;  // width limit keeps S from overflowing
        if (S[0] == 'C') {  // CHANGE u t
            int u, t;
            if (scanf("%d%d", &u, &t) != 2) break;
            update(A[u], t, 1, n, 1);
        } else if (S[1] == 'M' || S[1] == 'S') {  // QMAX / QSUM x y
            int x, y;
            if (scanf("%d%d", &x, &y) != 2) break;
            ll sum;
            int mx;
            path_query(x, y, sum, mx);
            if (S[1] == 'M') printf("%d\n", mx);
            else printf("%lld\n", sum);
        }
    }
    return 0;
}