BZOJ1036: [ZJOI2008]树的统计Count(树链剖分)

解题思路:

树链剖分裸题,线段树维护区间和和最大值。

代码:

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<algorithm>
  4 #define lll spc<<1
  5 #define rrr spc<<1|1
  6 typedef long long lnt;
  7 using std::max;
  8 using std::swap;
  9 struct trnt{
 10     lnt ms;
 11     lnt val;
 12 }tr[1000000];
 13 struct pnt{
 14     int hd;
 15     int fa;
 16     int dp;
 17     int tp;
 18     lnt vl;
 19     int ind;
 20     int wgt;
 21     int mxs;
 22 }p[1000000];
 23 struct ent{
 24     int twd;
 25     int lst;
 26 }e[1000000];
 27 int cnt;
 28 int n,m;
 29 int dfn;
 30 char cmd[100];
 31 lnt vali[1000000];
 32 void pushup(int spc)
 33 {
 34     tr[spc].val=tr[lll].val+tr[rrr].val;
 35     tr[spc].ms=max(tr[lll].ms,tr[rrr].ms);
 36     return ;
 37 }
 38 void Chg(int spc,lnt x)
 39 {
 40     tr[spc].val=x;
 41     tr[spc].ms=x;
 42     return ;
 43 }
 44 void build(int l,int r,int spc)
 45 {
 46     if(l==r)
 47     {
 48         Chg(spc,vali[l]);
 49         return ;
 50     }
 51     int mid=(l+r)>>1;
 52     build(l,mid,lll);
 53     build(mid+1,r,rrr);
 54     pushup(spc);
 55     return ;
 56 }
 57 void update(int l,int r,int pos,int spc,lnt v)
 58 {
 59     if(l==r)
 60     {
 61         Chg(spc,v);
 62         return ;
 63     }
 64     int mid=(l+r)>>1;
 65     if(pos<=mid)
 66         update(l,mid,pos,lll,v);
 67     else
 68         update(mid+1,r,pos,rrr,v);
 69     pushup(spc);
 70     return ;
 71 }
 72 lnt Max(int l,int r,int ll,int rr,int spc)
 73 {
 74     if(l>rr||ll>r)
 75         return -0x7f3f3f3f7f7f7f7fll;
 76     if(ll<=l&&r<=rr)
 77         return tr[spc].ms;
 78     int mid=(l+r)>>1;
 79     return max(Max(l,mid,ll,rr,lll),Max(mid+1,r,ll,rr,rrr));
 80 }
 81 lnt Sum(int l,int r,int ll,int rr,int spc)
 82 {
 83     if(l>rr||ll>r)
 84         return 0ll;
 85     if(ll<=l&&r<=rr)
 86         return tr[spc].val;
 87     int mid=(l+r)>>1;
 88     return Sum(l,mid,ll,rr,lll)+Sum(mid+1,r,ll,rr,rrr);
 89 }
 90 void ade(int f,int t)
 91 {
 92     cnt++;
 93     e[cnt].twd=t;
 94     e[cnt].lst=p[f].hd;
 95     p[f].hd=cnt;
 96     return ;
 97 }
 98 void Basic_dfs(int x,int f)
 99 {
100     p[x].fa=f;
101     p[x].dp=p[f].dp+1;
102     p[x].wgt=1;
103     int mxs=-1;
104     for(int i=p[x].hd;i;i=e[i].lst)
105     {
106         int to=e[i].twd;
107         if(to==f)
108             continue;
109         Basic_dfs(to,x);
110         p[x].wgt+=p[to].wgt;
111         if(mxs<p[to].wgt)
112         {
113             mxs=p[to].wgt;
114             p[x].mxs=to;
115         }
116     }
117     return ;
118 }
119 void Build_dfs(int x,int top)
120 {
121     if(!x)    
122         return ;
123     p[x].ind=++dfn;
124     vali[dfn]=p[x].vl;
125     p[x].tp=top;
126     Build_dfs(p[x].mxs,top);
127     for(int i=p[x].hd;i;i=e[i].lst)
128     {
129         int to=e[i].twd;
130         if(p[to].ind)
131             continue;
132         Build_dfs(to,to);
133     }
134     return ;
135 }
136 int Lca(int x,int y)
137 {
138     while(p[x].tp!=p[y].tp)
139     {
140         if(p[p[x].tp].dp<p[p[y].tp].dp)
141             swap(x,y);
142         x=p[p[x].tp].fa;
143     }
144     if(p[x].dp>p[y].dp)
145         swap(x,y);
146     return x;
147 }
148 lnt Max1(int x,int y)
149 {
150     int lca=Lca(x,y);
151     lnt ans=-0x3f3f3f3f3f3f3f3fll;
152     while(p[x].tp!=p[lca].tp)
153     {
154         ans=max(ans,Max(1,dfn,p[p[x].tp].ind,p[x].ind,1));
155         x=p[p[x].tp].fa;
156     }
157     ans=max(ans,Max(1,dfn,p[lca].ind,p[x].ind,1));
158     while(p[y].tp!=p[lca].tp)
159     {
160         ans=max(ans,Max(1,dfn,p[p[y].tp].ind,p[y].ind,1));
161         y=p[p[y].tp].fa;
162     }
163     ans=max(ans,Max(1,dfn,p[lca].ind,p[y].ind,1));
164     return ans;
165 }
166 lnt Sum1(int x,int y)
167 {
168     int lca=Lca(x,y);
169     lnt ans=0;
170     while(p[x].tp!=p[lca].tp)
171     {
172         ans+=Sum(1,dfn,p[p[x].tp].ind,p[x].ind,1);
173         x=p[p[x].tp].fa;
174     }
175     ans+=Sum(1,dfn,p[lca].ind,p[x].ind,1);
176     while(p[y].tp!=p[lca].tp)
177     {
178         ans+=Sum(1,dfn,p[p[y].tp].ind,p[y].ind,1);
179         y=p[p[y].tp].fa;
180     }
181     ans+=Sum(1,dfn,p[lca].ind,p[y].ind,1);
182     ans-=p[lca].vl;
183     return ans;
184 }
185 int main()
186 {
187     scanf("%d",&n);
188     for(int i=1;i<n;i++)
189     {
190         int x,y;
191         scanf("%d%d",&x,&y);
192         ade(x,y);
193         ade(y,x);
194     }
195     for(int i=1;i<=n;i++)
196         scanf("%lld",&p[i].vl);
197     Basic_dfs(1,1);
198     Build_dfs(1,1);
199     build(1,dfn,1);
200     scanf("%d",&m);
201     while(m--)
202     {
203         scanf("%s",cmd+1);
204         if(cmd[1]=='C')
205         {
206             int x;
207             lnt y;
208             scanf("%d%lld",&x,&y);
209             p[x].vl=y;
210             update(1,dfn,p[x].ind,1,y);
211         }else if(cmd[2]=='M')
212         {
213             int x,y;
214             scanf("%d%d",&x,&y);
215             printf("%lld\n",Max1(x,y));
216         }else{
217             int x,y;
218             scanf("%d%d",&x,&y);
219             printf("%lld\n",Sum1(x,y));
220         }
221     }
222     return 0;
223 }

 

posted @ 2018-09-29 19:40  Unstoppable728  阅读(177)  评论(0编辑  收藏  举报