bzoj 1036 树的统计Count

题目大意:

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w

我们将以下面的形式来要求你对这棵树完成一些操作:

I. CHANGE u t : 把结点u的权值改为t

II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值

III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

思路:

树链剖分

衣服都不穿的

搞到线段树里,然后维护维护

背一背代码

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cmath>
  4 #include<cstdlib>
  5 #include<cstring>
  6 #include<algorithm>
  7 #include<vector>
  8 #include<queue>
  9 #define inf 2139062143
 10 #define ll long long
 11 #define MAXN 30101
 12 #define MOD
 13 using namespace std;
 14 inline int read()
 15 {
 16     int x=0,f=1;char ch=getchar();
 17     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
 18     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
 19     return x*f;
 20 }
 21 int n,Cnt,nxt[MAXN*2],fst[MAXN],to[MAXN*2],val[MAXN];
 22 int fa[MAXN],dep[MAXN],bl[MAXN],cnt[MAXN],hsh[MAXN];
 23 struct data{int mx,sum,l,r;}tr[MAXN*3];
 24 void add(int u,int v) {nxt[++Cnt]=fst[u],fst[u]=Cnt,to[Cnt]=v;}
 25 void build(int x)
 26 {
 27     for(int i=fst[x];i;i=nxt[i])
 28     {
 29         if(to[i]==fa[x]) continue;
 30         dep[to[i]]=dep[x]+1;
 31         fa[to[i]]=x;
 32         build(to[i]);
 33         cnt[x]+=cnt[to[i]];
 34     }
 35     cnt[x]++;
 36 }
 37 void Build(int x,int chn)
 38 {
 39     int hvs=0;hsh[x]=++Cnt,bl[x]=chn;
 40     for(int i=fst[x];i;i=nxt[i])
 41         if(fa[x]!=to[i]&&cnt[hvs]<cnt[to[i]]) hvs=to[i];
 42     if(!hvs) return ;
 43     Build(hvs,chn);
 44     for(int i=fst[x];i;i=nxt[i])
 45         if(fa[x]!=to[i]&&hvs!=to[i]) Build(to[i],to[i]);
 46 }
 47 void s_build(int k,int l,int r)
 48 {
 49     tr[k].l=l,tr[k].r=r;
 50     if(l==r) return ;
 51     int mid=(l+r)>>1;
 52     s_build(k<<1,l,mid);
 53     s_build(k<<1|1,mid+1,r);
 54 }
 55 void upd(int k,int pos,int x)
 56 {;
 57     int l=tr[k].l,r=tr[k].r;
 58     if(l==r) {tr[k].mx=tr[k].sum=x;return ;}
 59     int mid=(l+r)>>1;
 60     if(mid>=pos) upd(k<<1,pos,x);
 61     else upd(k<<1|1,pos,x);
 62     tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
 63     tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
 64 }
 65 int q_sum(int k,int a,int b)
 66 {
 67     int l=tr[k].l,r=tr[k].r;
 68     if(l==a&&r==b) return tr[k].sum;
 69     int mid=(l+r)>>1;
 70     if(b<=mid) return q_sum(k<<1,a,b);
 71     if(a>mid) return q_sum(k<<1|1,a,b);
 72     else return q_sum(k<<1,a,mid)+q_sum(k<<1|1,mid+1,b);
 73 }
 74 int q_mx(int k,int a,int b)
 75 {
 76     int l=tr[k].l,r=tr[k].r;
 77     if(l==a&&r==b) return tr[k].mx;
 78     int mid=(l+r)>>1;
 79     if(b<=mid) return q_mx(k<<1,a,b);
 80     if(a>mid) return q_mx(k<<1|1,a,b);
 81     else return max(q_mx(k<<1,a,mid),q_mx(k<<1|1,mid+1,b));
 82 }
 83 int main()
 84 {
 85     n=read();int a,b,res;
 86     for(int i=1;i<n;i++) {a=read(),b=read();add(a,b);add(b,a);}
 87     for(int i=1;i<=n;i++) val[i]=read();fa[1]=1;
 88     build(1);Cnt=0;
 89     Build(1,1);
 90     s_build(1,1,n);
 91     for(int i=1;i<=n;i++) upd(1,hsh[i],val[i]);
 92     int T=read();
 93     char ch[8];
 94     while(T--)
 95     {
 96         scanf("%s",ch);a=read(),b=read();
 97         if(ch[0]=='C') {val[a]=b;upd(1,hsh[a],b);}
 98         else if(ch[1]=='M')
 99         {
100             res=-inf;
101             while(bl[a]!=bl[b])
102             {
103                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
104                 res=max(res,q_mx(1,hsh[bl[a]],hsh[a]));
105                 a=fa[bl[a]];
106             }
107             res=max(res,q_mx(1,min(hsh[a],hsh[b]),max(hsh[a],hsh[b])));
108             printf("%d\n",res);
109         }
110         else if(ch[1]=='S')
111         {
112             res=0;
113             while(bl[a]!=bl[b])
114             {
115                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
116                 res+=q_sum(1,hsh[bl[a]],hsh[a]);
117                 a=fa[bl[a]];
118             }
119             res+=q_sum(1,min(hsh[a],hsh[b]),max(hsh[a],hsh[b]));
120             printf("%d\n",res);
121         }
122     }
123 }
View Code

 UPD 2018.9.19

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<algorithm>
  4 #include<cmath>
  5 #include<cstring>
  6 #include<cstdlib>
  7 #include<set>
  8 #include<map>
  9 #include<vector>
 10 #include<stack>
 11 #include<queue>
 12 #define ll long long
 13 #define inf 2147383611
 14 #define MAXN 500100
 15 using namespace std;
 16 inline int read()
 17 {
 18     int x=0,f=1;
 19     char ch;ch=getchar();
 20     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
 21     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
 22     return x*f;
 23 }
 24 int n,cnt,nxt[MAXN<<1],fst[MAXN],to[MAXN<<1],val[MAXN];
 25 int fa[MAXN],dep[MAXN],bl[MAXN],sz[MAXN],hsh[MAXN],mx[MAXN<<2],sum[MAXN<<2];
 26 void add(int u,int v) {nxt[++cnt]=fst[u],fst[u]=cnt,to[cnt]=v;}
 27 void dfs(int x)
 28 {
 29     sz[x]=1,dep[x]=dep[fa[x]]+1;
 30     for(int i=fst[x];i;i=nxt[i])
 31         if(to[i]!=fa[x]) {fa[to[i]]=x;dfs(to[i]);sz[x]+=sz[to[i]];}
 32 }
 33 void Dfs(int x,int anc)
 34 {
 35     hsh[x]=++cnt,bl[x]=anc;int hvs=0;
 36     for(int i=fst[x];i;i=nxt[i])
 37         if(dep[to[i]]>dep[x]&&sz[hvs]<sz[to[i]]) hvs=to[i];
 38     if(!hvs) return ;Dfs(hvs,anc);
 39     for(int i=fst[x];i;i=nxt[i])
 40         if(dep[to[i]]>dep[x]&&to[i]!=hvs) {Dfs(to[i],to[i]);}
 41 }
 42 void mdf(int k,int l,int r,int x,int w)
 43 {
 44     if(l==r) {mx[k]=sum[k]=w;return ;}
 45     int mid=(l+r)>>1;
 46     if(x<=mid) mdf(k<<1,l,mid,x,w);
 47     else mdf(k<<1|1,mid+1,r,x,w);
 48     mx[k]=max(mx[k<<1],mx[k<<1|1]),sum[k]=sum[k<<1]+sum[k<<1|1];
 49 }
 50 int querys(int k,int l,int r,int a,int b)
 51 {
 52     if(l==a&&r==b) return sum[k];
 53     int mid=(l+r)>>1;
 54     if(b<=mid) return querys(k<<1,l,mid,a,b);
 55     else if(a>mid) return querys(k<<1|1,mid+1,r,a,b);
 56     else return querys(k<<1,l,mid,a,mid)+querys(k<<1|1,mid+1,r,mid+1,b);
 57 }
 58 int querym(int k,int l,int r,int a,int b)
 59 {
 60     if(l==a&&r==b) return mx[k];
 61     int mid=(l+r)>>1;
 62     if(b<=mid) return querym(k<<1,l,mid,a,b);
 63     else if(a>mid) return querym(k<<1|1,mid+1,r,a,b);
 64     else return max(querym(k<<1,l,mid,a,mid),querym(k<<1|1,mid+1,r,mid+1,b));
 65 }
 66 int main()
 67 {
 68     n=read();int a,b,res,T;char ch[8];
 69     for(int i=1;i<n;i++) {a=read(),b=read();add(a,b);add(b,a);}
 70     dfs(1);cnt=0;Dfs(1,1);
 71     for(int i=1;i<=n;i++) mdf(1,1,n,hsh[i],read());
 72     T=read();
 73     while(T--)
 74     {
 75         scanf("%s",ch);a=read(),b=read();
 76         if(ch[0]=='C') mdf(1,1,n,hsh[a],b);
 77         else if(ch[1]=='M')
 78         {
 79             res=-inf;
 80             while(bl[a]!=bl[b])
 81             {
 82                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
 83                 res=max(res,querym(1,1,n,hsh[bl[a]],hsh[a]));
 84                 a=fa[bl[a]];
 85             }
 86             res=max(res,querym(1,1,n,min(hsh[a],hsh[b]),max(hsh[a],hsh[b])));
 87             printf("%d\n",res);
 88         }
 89         else if(ch[1]=='S')
 90         {
 91             res=0;
 92             while(bl[a]!=bl[b])
 93             {
 94                 if(dep[bl[a]]<dep[bl[b]]) swap(a,b);
 95                 res+=querys(1,1,n,hsh[bl[a]],hsh[a]);
 96                 a=fa[bl[a]];
 97             }
 98             res+=querys(1,1,n,min(hsh[a],hsh[b]),max(hsh[a],hsh[b]));
 99             printf("%d\n",res);
100         }
101     }
102 }
View Code

 

posted @ 2017-12-02 13:23  jack_yyc  阅读(159)  评论(0编辑  收藏  举报