【BZOJ 1036】 树的统计count
题目
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
分析
树链剖分
代码
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<vector> using namespace std; #define MN 30000 #define fINF -30000000 int fa[MN+5],W[MN+5],ans[MN+5],dfn[MN+5],sons[MN+5],fl[MN+5],head[MN+5],ccnt=0,cnt=0; struct data{int to,next;}e[MN*2+10]; void ins(int u,int v){ e[++ccnt].to=v;e[ccnt].next=head[u];head[u]=ccnt; e[++ccnt].to=u;e[ccnt].next=head[v];head[v]=ccnt; } struct TREE{int val,max;}t[MN*4+10]; int n,q; void update(int k,int l,int r,int q,int v){ if(l==r) {t[k].val=t[k].max=v;return;} int mid=(l+r)/2; if(q<=mid) update(k<<1,l,mid,q,v); if(q>mid) update(k<<1|1,mid+1,r,q,v); t[k].val=t[k<<1].val+t[k<<1|1].val; t[k].max=max(t[k<<1].max,t[k<<1|1].max); } void dfs1(int x){ sons[x]=1; for(int i=head[x];i;i=e[i].next){ if(e[i].to==fa[x]) continue; fl[e[i].to]=fl[x]+1; fa[e[i].to]=x; dfs1(e[i].to); sons[x]+=sons[e[i].to]; } } void dfs2(int x,int chain){ int k=0; dfn[x]=++cnt; ans[x]=chain; for(int i=head[x];i;i=e[i].next) if(fl[e[i].to]>fl[x]&&sons[e[i].to]>sons[k]) k=e[i].to; if(k==0) return; dfs2(k,chain); for(int i=head[x];i;i=e[i].next) if(fl[e[i].to]>fl[x]&&k!=e[i].to) dfs2(e[i].to,e[i].to); } int query_max(int k,int l,int r,int a,int b){ if(a<=l&&r<=b) return t[k].max; int m=(l+r)/2,anss=fINF; if(a<=m) anss=max(anss,query_max(k<<1,l,m,a,b)); if(m<b) anss=max(anss,query_max(k<<1|1,m+1,r,a,b)); return anss; } int query_sum(int k,int l,int r,int a,int b){ if(a<=l&&r<=b) return t[k].val; int m=(l+r)/2,anss=0; if(a<=m) anss+=query_sum(k<<1,l,m,a,b); if(m<b) anss+=query_sum(k<<1|1,m+1,r,a,b); return anss; } int find_sum(int x,int y){ int sum=0; while(ans[x]!=ans[y]){ if(fl[ans[x]]<fl[ans[y]]) swap(x,y); sum+=query_sum(1,1,n,dfn[ans[x]],dfn[x]); x=fa[ans[x]]; } if(dfn[x]>dfn[y]) swap(x,y); sum+=query_sum(1,1,n,dfn[x],dfn[y]); return sum; } int find_max(int x,int y){ int mx=fINF; while(ans[x]!=ans[y]){ if(fl[ans[x]]<fl[ans[y]]) swap(x,y); mx=max(mx,query_max(1,1,n,dfn[ans[x]],dfn[x])); x=fa[ans[x]]; } if(dfn[x]>dfn[y]) swap(x,y); mx=max(mx,query_max(1,1,n,dfn[x],dfn[y])); return mx; } void solve(int k,int a,int b){ if(k==1) printf("%d\n",find_max(a,b)); if(k==2) printf("%d\n",find_sum(a,b)); if(k==3) update(1,1,n,dfn[a],b); } int main(){ int u,v,ro; scanf("%d",&n); for(int i=1;i<n;i++) scanf("%d%d",&u,&v),ins(u,v); dfs1(1); dfs2(1,1); for(int i=1;i<=n;i++) scanf("%d",&W[i]),update(1,1,n,dfn[i],W[i]); scanf("%d",&q); while(q--){ char ch=getchar(); int k; int x1=0,f1=1,x2=0,f2=1; while(ch<'0'||ch>'9'){ if(ch=='X') k=1; if(ch=='U') k=2; if(ch=='H') k=3; if(ch=='-') f1=-1; ch=getchar(); } while(ch>='0'&&ch<='9') x1=x1*10+ch-'0',ch=getchar(); while(ch<'0'||ch>'9') f2=ch=='-'?-1:1,ch=getchar(); while(ch>='0'&&ch<='9') x2=x2*10+ch-'0',ch=getchar(); solve(k,x1*f1,x2*f2); } return 0; }
————————————————————————————————————
来自PaperCloud的博客,未经允许,请勿转载,谢谢。
致虚极,守静笃,万物并作,吾以观其复