BZOJ1036[ZJOI2008]树的统计Count 题解

题目大意:

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。有一些操作:1.把结点u的权值改为t;2.询问从点u到点v的路径上的节点的最大权值 3.询问从点u到点v的路径上的节点的权值和。

思路:

  进行轻重树链剖分,再根据每个节点的dfs序建立线段树,维护其最大值以及和,询问时用树剖后的结果将重链作为区间一段一段求和。

代码:

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<iostream>
  4 #define M 1000009
  5 using namespace std;
  6 
  7 int n,dfn,cnt,to[M],next[M],head[M],size[M],vis[M],deep[M],fa[M],top[M],w[M],mx[M],sum[M],id[M];
  8 
  9 void add(int x,int y)
 10 {
 11     to[++cnt]=y,next[cnt]=head[x],head[x]=cnt;
 12 }
 13 
 14 void dfs1(int x)
 15 {
 16     size[x]=vis[x]=1;
 17     for (int i=head[x];i;i=next[i])
 18         if (!vis[to[i]])
 19         {
 20             deep[to[i]]=deep[x]+1;
 21             fa[to[i]]=x;
 22             dfs1(to[i]);
 23             size[x]+=size[to[i]];
 24         }
 25 }
 26 
 27 void dfs2(int x,int chain)
 28 {
 29     int k=0,i;
 30     id[x]=++dfn;
 31     top[x]=chain;
 32     for (i=head[x];i;i=next[i])
 33         if (deep[to[i]]>deep[x] && size[to[i]]>size[k]) k=to[i];
 34     if (!k) return;
 35     dfs2(k,chain);
 36     for (i=head[x];i;i=next[i])
 37         if (deep[to[i]]>deep[x] && to[i]!=k) dfs2(to[i],to[i]);
 38 }
 39 
 40 int LCA(int x,int y)
 41 {
 42     for (;top[x]!=top[y];x=fa[top[x]])
 43         if (deep[top[x]]<deep[top[y]]) swap(x,y);
 44     return deep[x]<deep[y]?x:y;
 45 }
 46 
 47 void change(int l,int r,int x,int y,int cur)
 48 {
 49     if (l==r)
 50     {
 51         mx[cur]=sum[cur]=y;
 52         return;
 53     }
 54     int mid=l+r>>1;
 55     if (x<=mid) change(l,mid,x,y,cur<<1);
 56     else change(mid+1,r,x,y,cur<<1|1);
 57     mx[cur]=max(mx[cur<<1],mx[cur<<1|1]);
 58     sum[cur]=sum[cur<<1]+sum[cur<<1|1];
 59 }
 60 
 61 int SUM(int L,int R,int l,int r,int cur)
 62 {
 63     if (l<=L && r>=R) return sum[cur];
 64     int mid=L+R>>1;
 65     if (l>mid) return SUM(mid+1,R,l,r,cur<<1|1);
 66     else if (r<=mid) return SUM(L,mid,l,r,cur<<1);
 67          else return SUM(L,mid,l,r,cur<<1)+SUM(mid+1,R,mid+1,r,cur<<1|1);
 68 }
 69 
 70 int MAX(int L,int R,int l,int r,int cur)
 71 {
 72     if (l<=L && r>=R) return mx[cur];
 73     int mid=L+R>>1;
 74     if (l>mid) return MAX(mid+1,R,l,r,cur<<1|1);
 75     else if (r<=mid) return MAX(L,mid,l,r,cur<<1);
 76          else return max(MAX(L,mid,l,mid,cur<<1),MAX(mid+1,R,mid+1,r,cur<<1|1));
 77 }
 78 
 79 int Sum(int x,int y)
 80 {
 81     int ans=0;
 82     for (;top[x]!=top[y];x=fa[top[x]])
 83     {
 84         if (deep[top[x]]<deep[top[y]]) swap(x,y);
 85         ans+=SUM(1,n,id[top[x]],id[x],1);
 86     }
 87     if (deep[x]>deep[y]) swap(x,y);
 88     return ans+SUM(1,n,id[x],id[y],1);
 89 }
 90 
 91 int Max(int x,int y)
 92 {
 93     int ans=-999999999;
 94     for (;top[x]!=top[y];x=fa[top[x]])
 95     {
 96         if (deep[top[x]]<deep[top[y]]) swap(x,y);
 97         ans=max(ans,MAX(1,n,id[top[x]],id[x],1));
 98     }
 99     if (deep[x]>deep[y]) swap(x,y);
100     return max(ans,MAX(1,n,id[x],id[y],1));
101 }
102 
103 int main()
104 {
105     int m,i,x,y;
106     scanf("%d",&n);
107     for (i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x);
108     dfs1(1);
109     dfs2(1,1);
110     for (i=1;i<=n;i++) scanf("%d",&w[i]),change(1,n,id[i],w[i],1);
111     scanf("%d",&m);
112     for (i=1;i<=m;i++)
113     {
114         char ch[9];
115         scanf("%s%d%d",ch,&x,&y);
116         if (ch[0]=='C') w[x]=y,change(1,n,id[x],y,1);
117         else
118             if (ch[1]=='S') printf("%d\n",Sum(x,y));
119             else printf("%d\n",Max(x,y));
120     }
121     return 0;
122 }

 

posted @ 2016-08-03 18:17  HHshy  阅读(186)  评论(0编辑  收藏  举报