bzoj1036: [ZJOI2008]树的统计Count 树链剖分+线段树
入门题 + 熟悉代码
/**************************************************************
    Problem: 1036
    User: 96655
    Language: C++
    Result: Accepted
    Time:2472 ms
    Memory:5536 kb
****************************************************************/
// Tree path queries (sum / max of node weights) with point updates,
// implemented with heavy-light decomposition on top of a segment tree.
//
// Input, as read by main(): n, then n-1 edges "u v", then n node weights,
// then q, then q operations "<word> x y":
//   word starting with "QM" -> print the max weight on the path x..y
//   any other "Q..." word   -> print the sum of weights on the path x..y
//   anything else           -> set the weight of node x to y
// The edges are assumed to form a tree over nodes 1..n (not verified here).
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<iostream>
#include<cstdlib>
#include<queue>
#include<map>
#include<set>
#include<vector>
#include<cmath>
#include<stack>
#include<utility>
#include<climits>
using namespace std;

// Array capacity; nodes are numbered from 1, so at most maxn-1 nodes fit.
const int maxn=30005;
// Identity element for max(); never used in arithmetic.
const int NEG_INF=INT_MIN;

// Adjacency list stored as linked edge records; -1 terminates a list.
struct Edge
{
    int v,next;
} edge[maxn*2];

// va[u]: weight of node u; p: number of edge records used;
// head[u]: first edge record of u; clk: last segment-tree position handed out;
// n: number of nodes in the current test case.
int va[maxn],p,head[maxn],clk,n;

// Resets the per-test-case graph state.
void init()
{
    memset(head,-1,sizeof(head));
    p=clk=0;
}

// Adds the directed edge u -> v to the adjacency list.
void addedge(int u,int v)
{
    edge[p].v=v;
    edge[p].next=head[u];
    head[u]=p++;
}

// id[u]: segment-tree position of u; sz[u]: subtree size; dep[u]: depth;
// fa[u]: parent (the root is its own parent); son[u]: child with the largest
// subtree, or -1 for a leaf; top[u]: first node of the chain containing u.
int id[maxn],sz[maxn],dep[maxn],fa[maxn],son[maxn],top[maxn];
// xx[pos]: weight of the node placed at segment-tree position pos.
int xx[maxn];

// First pass: fills dep, fa, sz and son for the subtree rooted at u.
// f is u's parent, d is u's depth.
void dfs1(int u,int f,int d)
{
    dep[u]=d;
    fa[u]=f;
    sz[u]=1;
    son[u]=-1;
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==f)continue;
        dfs1(v,u,d+1);
        sz[u]+=sz[v];
        if(son[u]==-1||sz[v]>sz[son[u]])
            son[u]=v;
    }
}

// Second pass: assigns positions so that every chain occupies a contiguous
// range. tp is the first node of the chain u belongs to.
void dfs2(int u,int tp)
{
    id[u]=++clk;
    xx[id[u]]=va[u];
    top[u]=tp;
    // Visit the heavy child first so it gets the next position on this chain.
    if(son[u]!=-1)dfs2(son[u],tp);
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].v;
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v); // every other child starts a chain of its own
    }
}

// Segment tree over positions 1..n; node rt has children rt*2 and rt*2+1.
// NOTE(review): sums are kept in int; this presumably relies on the problem's
// bounds on n and on the weights - confirm before reusing with larger values.
int sum[maxn*4],maxv[maxn*4];

// Recomputes node rt from its two children.
void pushup(int rt)
{
    sum[rt]=sum[rt*2]+sum[rt*2+1];
    maxv[rt]=max(maxv[rt*2],maxv[rt*2+1]);
}

// Builds the subtree rt covering positions [l,r] from xx[].
void build(int rt,int l,int r)
{
    if(l==r)
    {
        sum[rt]=maxv[rt]=xx[l];
        return;
    }
    int m=(l+r)>>1;
    build(rt*2,l,m);
    build(rt*2+1,m+1,r);
    pushup(rt);
}

// Sets position pos to value c inside the subtree rt covering [l,r].
void update(int rt,int l,int r,int pos,int c)
{
    if(l==r)
    {
        maxv[rt]=sum[rt]=c;
        return;
    }
    int m=(l+r)>>1;
    if(pos<=m)update(rt*2,l,m,pos,c);
    else update(rt*2+1,m+1,r,pos,c);
    pushup(rt);
}

// Returns the sum over positions [x,y] within the subtree rt covering [l,r].
int query1(int rt,int l,int r,int x,int y)
{
    if(x<=l&&r<=y)
    {
        return sum[rt];
    }
    int ans=0;
    int m=(l+r)>>1;
    if(x<=m)ans+=query1(rt*2,l,m,x,y);
    if(y>m)ans+=query1(rt*2+1,m+1,r,x,y);
    return ans;
}

// Returns the maximum over positions [x,y] within the subtree rt covering
// [l,r]; NEG_INF if the two ranges do not intersect.
int query2(int rt,int l,int r,int x,int y)
{
    if(x<=l&&r<=y)
    {
        return maxv[rt];
    }
    int ans=NEG_INF;
    int m=(l+r)>>1;
    if(x<=m)ans=max(ans,query2(rt*2,l,m,x,y));
    if(y>m)ans=max(ans,query2(rt*2+1,m+1,r,x,y));
    return ans;
}

// Returns the sum of the weights on the tree path between u and v.
int getsum(int u,int v)
{
    int ans=0;
    while(top[u]!=top[v])
    {
        // Always lift the endpoint whose chain starts deeper.
        if(dep[top[u]]<dep[top[v]])
            swap(u,v);
        ans+=query1(1,1,n,id[top[u]],id[u]);
        u=fa[top[u]];
    }
    // Same chain now: the shallower node has the smaller position.
    if(dep[u]>dep[v])swap(u,v);
    ans+=query1(1,1,n,id[u],id[v]);
    return ans;
}

// Returns the maximum weight on the tree path between u and v.
int getmax(int u,int v)
{
    int ans=NEG_INF;
    while(top[u]!=top[v])
    {
        // Always lift the endpoint whose chain starts deeper.
        if(dep[top[u]]<dep[top[v]])
            swap(u,v);
        ans=max(ans,query2(1,1,n,id[top[u]],id[u]));
        u=fa[top[u]];
    }
    // Same chain now: the shallower node has the smaller position.
    if(dep[u]>dep[v])swap(u,v);
    ans=max(ans,query2(1,1,n,id[u],id[v]));
    return ans;
}

// True when x is a node number of the current test case.
static bool valid_node(int x)
{
    return x>=1&&x<=n;
}

// Processes test cases until end of input. Returns 0 on clean end of input
// and 1 as soon as the input is malformed (truncated, or a node count or node
// number outside the supported range).
int main()
{
    // Compare against 1: a non-numeric token makes scanf return 0, which
    // must end the loop just like EOF does.
    while(scanf("%d",&n)==1)
    {
        // n<1 would make build() recurse without end; n>=maxn would overrun
        // every array.
        if(n<1||n>=maxn)return 1;
        init();
        for(int i=1; i<n; ++i)
        {
            int u,v;
            if(scanf("%d%d",&u,&v)!=2)return 1;
            if(!valid_node(u)||!valid_node(v))return 1;
            addedge(u,v);
            addedge(v,u);
        }
        for(int i=1; i<=n; i++)
            if(scanf("%d",&va[i])!=1)return 1;
        dfs1(1,1,0);
        dfs2(1,1);
        build(1,1,n);
        int q;
        if(scanf("%d",&q)!=1)return 1;
        while(q-->0)
        {
            char s[20];
            int x,y;
            // The width keeps an over-long token from overflowing s.
            if(scanf("%19s%d%d",s,&x,&y)!=3)return 1;
            if(!valid_node(x))return 1;
            if(s[0]=='Q')
            {
                if(!valid_node(y))return 1;
                if(s[1]=='M')
                    printf("%d\n",getmax(x,y));
                else printf("%d\n",getsum(x,y));
            }
            else update(1,1,n,id[x],y); // y is the new weight, not a node
        }
    }
    return 0;
}