【BZOJ】1036 [ZJOI2008]树的统计Count
【算法】树链剖分+线段树
【题解】模板题,见http://www.cnblogs.com/onioncyc/p/6207462.html
调用线段树时要用新编号 pos[i]！！！
// BZOJ 1036 [ZJOI2008] Count: path-max / path-sum queries with point updates
// on a tree, solved with heavy-light decomposition (HLD) + a segment tree.
//
// Input: n, then n-1 undirected edges, then the n node weights, then Q queries:
//   CHANGE u t  -> set the weight of node u to t
//   QMAX u v    -> print the maximum weight on the path u..v
//   QSUM u v    -> print the sum of weights on the path u..v
#include <cctype>
#include <cstdio>
#include <algorithm>
#include <utility>

const int maxn = 30010, inf = 0x3f3f3f3f;

// HLD data, indexed by original (1-based) node id:
//   pos[x]  - position of x in HLD order; this is the segment-tree index
//   top[x]  - head of the heavy chain containing x
//   f[x]    - parent of x (stays 0 for the root)
//   deep[x] - depth of x (root = 0)
//   size[x] - number of nodes in the subtree of x
int pos[maxn], top[maxn], dfsnum, f[maxn], deep[maxn], size[maxn];
// Adjacency list: first[u] is the most recently added edge of u (0 = none).
int first[maxn], n, tot, a[maxn];

// Each undirected edge is stored twice, and edge ids start at 1.
struct edge { int u, v, from; } e[maxn * 2];

// A recursively built segment tree over n leaves can touch node indices up to
// about 4n, so reserve 4 * maxn nodes (3 * maxn is not safe in general).
struct tree { int l, r, sum, mx; } t[maxn << 2];

// Reads one decimal integer from stdin, skipping any non-digit characters
// before it; a '-' seen while skipping makes the result negative. The
// character right after the last digit is consumed. Returns 0 if input ends
// before a digit is found (instead of looping forever on EOF).
int read()
{
    int c = getchar(), sign = 1, s = 0;  // int, not char: keeps EOF distinct
    while (c != EOF && !isdigit(c)) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        s = s * 10 + (c - '0');
        c = getchar();
    }
    return s * sign;
}

// Appends the directed edge u -> v to u's adjacency list.
void insert(int u, int v)
{
    ++tot;
    e[tot].u = u;
    e[tot].v = v;
    e[tot].from = first[u];
    first[u] = tot;
}

// First DFS: fills depth, parent and subtree size for the subtree rooted at x.
void dfs1(int x, int fa)
{
    size[x] = 1;
    for (int i = first[x]; i; i = e[i].from) {
        int y = e[i].v;
        if (y == fa) continue;
        deep[y] = deep[x] + 1;
        f[y] = x;
        dfs1(y, x);
        size[x] += size[y];
    }
}

// Second DFS: assigns HLD positions. The heavy child (largest subtree) is
// visited first, so every heavy chain occupies a contiguous range of pos[].
// tp is the head of the chain that x belongs to.
void dfs2(int x, int tp, int fa)
{
    pos[x] = ++dfsnum;
    top[x] = tp;
    int heavy = 0;  // size[0] == 0, so any real child beats this sentinel
    for (int i = first[x]; i; i = e[i].from)
        if (e[i].v != fa && size[e[i].v] > size[heavy]) heavy = e[i].v;
    if (heavy == 0) return;  // x is a leaf
    dfs2(heavy, tp, x);
    for (int i = first[x]; i; i = e[i].from)
        if (e[i].v != fa && e[i].v != heavy) dfs2(e[i].v, e[i].v, x);  // light child starts a new chain
}

// Builds the segment-tree skeleton: node k covers positions [l, r].
void build(int k, int l, int r)
{
    t[k].l = l;
    t[k].r = r;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(k << 1, l, mid);
    build(k << 1 | 1, mid + 1, r);
}

// Sets the value at segment-tree position x to y and refreshes sum/max upward.
void change(int k, int x, int y)
{
    int left = t[k].l, right = t[k].r;
    if (left == right) {
        t[k].mx = y;
        t[k].sum = y;
        return;
    }
    int mid = (left + right) >> 1;
    if (x <= mid) change(k << 1, x, y);
    else change(k << 1 | 1, x, y);
    t[k].sum = t[k << 1].sum + t[k << 1 | 1].sum;
    t[k].mx = std::max(t[k << 1].mx, t[k << 1 | 1].mx);
}

// Maximum over segment-tree positions [l, r] (requires l <= r).
int ask_mx(int k, int l, int r)
{
    int left = t[k].l, right = t[k].r;
    if (l <= left && right <= r) return t[k].mx;
    int mid = (left + right) >> 1, maxs = -inf;
    if (l <= mid) maxs = ask_mx(k << 1, l, r);
    if (r > mid) maxs = std::max(maxs, ask_mx(k << 1 | 1, l, r));
    return maxs;
}

// Sum over segment-tree positions [l, r] (requires l <= r).
int ask_sum(int k, int l, int r)
{
    int left = t[k].l, right = t[k].r;
    if (l <= left && right <= r) return t[k].sum;
    int mid = (left + right) >> 1, sums = 0;
    if (l <= mid) sums = ask_sum(k << 1, l, r);
    if (r > mid) sums += ask_sum(k << 1 | 1, l, r);
    return sums;
}

// Maximum node weight on the tree path between ORIGINAL node ids x and y.
// Repeatedly lifts the endpoint whose chain head is deeper, querying that
// chain's contiguous pos[] range, until both lie on the same chain.
int solve_mx(int x, int y)
{
    int maxs = -inf;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        maxs = std::max(maxs, ask_mx(1, pos[top[x]], pos[x]));
        x = f[top[x]];
    }
    if (pos[x] > pos[y]) std::swap(x, y);
    maxs = std::max(maxs, ask_mx(1, pos[x], pos[y]));
    return maxs;
}

// Sum of node weights on the tree path between ORIGINAL node ids x and y
// (same chain-lifting scheme as solve_mx).
int solve_sum(int x, int y)
{
    int sums = 0;
    while (top[x] != top[y]) {
        if (deep[top[x]] < deep[top[y]]) std::swap(x, y);
        sums += ask_sum(1, pos[top[x]], pos[x]);
        x = f[top[x]];
    }
    if (pos[x] > pos[y]) std::swap(x, y);
    sums += ask_sum(1, pos[x], pos[y]);
    return sums;
}

int main()
{
    n = read();
    if (n <= 0) return 0;  // build(1, 1, 0) would recurse forever
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        insert(u, v);
        insert(v, u);
    }
    for (int i = 1; i <= n; i++) a[i] = read();
    dfs1(1, -1);
    dfs2(1, 1, -1);
    build(1, 1, n);
    // The segment tree is indexed by HLD position pos[i], not by node id i.
    for (int i = 1; i <= n; i++) change(1, pos[i], a[i]);

    int Q = read();
    char ch[10];
    for (int i = 1; i <= Q; i++) {
        if (scanf("%9s", ch) != 1) break;  // width limit prevents overflowing ch
        int u = read(), v = read();
        // A direct segment-tree call must use the new index pos[u]. The path
        // queries take original ids because solve_* map them through pos[].
        if (ch[1] == 'H') change(1, pos[u], v);                    // CHANGE
        else if (ch[1] == 'M') printf("%d\n", solve_mx(u, v));    // QMAX
        else if (ch[1] == 'S') printf("%d\n", solve_sum(u, v));   // QSUM
    }
    return 0;
}