【bzoj1036】树的统计[ZJOI2008]树链剖分+线段树
题目传送门:1036: [ZJOI2008]树的统计Count
这道题是我第一次打树剖的板子,虽然代码有点长,但是“打起来很爽”,而且整道题只花了不到1.5h+,还是一遍过样例!一次提交AC!(难道前天膜cyc神犇真的起作用了?)
按捺不住激动的心情!!!
这道题是树剖+线段树模板题,树剖把树转化成线段后就可以单点修改+区间求和/区间求max了。
又臭又长的93行代码:
#include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<algorithm> using namespace std; struct point{ int sum,max; }a[200010]; int fir[30010],to[60010],ne[60010]; int vis[30010],fa[30010],dep[30010],size[30010],hson[30010],top[30010],id[30010]; int tot,n,m,q; char s[10]; void change(int now,int l,int r,int x,int k) { if(l==r)a[now].sum=a[now].max=k; else{ int mid=(l+r)>>1; if(x<=mid)change(now<<1,l,mid,x,k);else change(now<<1|1,mid+1,r,x,k); a[now].sum=a[now<<1].sum+a[now<<1|1].sum; a[now].max=max(a[now<<1].max,a[now<<1|1].max); } } int getsum(int now,int l,int r,int x,int y) { if(x<=l&&r<=y)return a[now].sum; else{ int mid=(l+r)>>1,sum=0; if(x<=mid)sum+=getsum(now<<1,l,mid,x,y); if(mid<y)sum+=getsum(now<<1|1,mid+1,r,x,y); return sum; } } int getmax(int now,int l,int r,int x,int y) { if(x<=l&&r<=y)return a[now].max; else{ int mid=(l+r)>>1,mx=-1<<30; if(x<=mid)mx=max(mx,getmax(now<<1,l,mid,x,y)); if(mid<y)mx=max(mx,getmax(now<<1|1,mid+1,r,x,y)); return mx; } } void add(int x,int y){to[++tot]=y; ne[tot]=fir[x]; fir[x]=tot;} void dfs1(int now,int d) { int i,mxson=0; size[now]=1; vis[now]=1; dep[now]=d; for(i=fir[now];i;i=ne[i]) if(!vis[to[i]]){ fa[to[i]]=now; dfs1(to[i],d+1); if(size[to[i]]>mxson)mxson=size[to[i]],hson[now]=to[i]; size[now]+=size[to[i]]; } } void dfs2(int now,int tp) { int i; top[now]=tp; id[now]=++tot; if(hson[now])dfs2(hson[now],tp); for(i=fir[now];i;i=ne[i]) if(to[i]!=hson[now]&&to[i]!=fa[now])dfs2(to[i],to[i]); } int main() { int i,x,y; scanf("%d",&n); tot=0; for(i=1;i<n;i++)scanf("%d%d",&x,&y),add(x,y),add(y,x); for(i=1;i<=n;i++)vis[i]=0; tot=0; dfs1(1,1); dfs2(1,1); for(i=1;i<=n;i++)scanf("%d",&x),change(1,1,n,id[i],x); scanf("%d",&q); for(i=1;i<=q;i++){ scanf("%s%d%d",s,&x,&y); if(s[0]=='C')change(1,1,n,id[x],y); else if(s[1]=='M'){ int mx=-1<<29; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]){int tmp=x; x=y; y=tmp;} mx=max(mx,getmax(1,1,n,id[top[x]],id[x])); x=fa[top[x]]; } if(dep[x]>dep[y]){int tmp=x; x=y; y=tmp;} printf("%d\n",max(mx,getmax(1,1,n,id[x],id[y]))); } else{ int sum=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]){int tmp=x; x=y; y=tmp;} sum+=getsum(1,1,n,id[top[x]],id[x]); x=fa[top[x]]; } if(dep[x]>dep[y]){int tmp=x; x=y; y=tmp;} printf("%d\n",sum+getsum(1,1,n,id[x],id[y])); } } }