「luogu2664」树上游戏

又臭又长的点分治+树上乱搞

  1 #include<bits/stdc++.h>
  2 #define ll long long
  3 using namespace std;
  4 const int N=100010,oo=INT_MAX;
  5 int n,color[N],cnt[N],cnt2[N],sub[N],nowchild;
  6 ll sum[N];
  7 vector<int>g[N];
  8 int f[N],siz[N],root,tot_node,sumsiz[N],cursumsiz[N];
  9 int timer,timer2,colorvis[N],curcolorvis[N],isfirst[N];
 10 bool vis[N];
 11 void getroot(int k,int fa){
 12     siz[k]=1,f[k]=0;
 13     int x;
 14     for(int i=0;i<g[k].size();i++){
 15         x=g[k][i];
 16         if(x==fa||vis[x]) continue;
 17         getroot(x,k);
 18         siz[k]+=siz[x],f[k]=max(f[k],siz[x]);
 19     }
 20     f[k]=max(f[k],tot_node-siz[k]);
 21     if(f[root]>f[k]) root=k;
 22     return;
 23 }
 24 void dfs0(int k,int fa){
 25     siz[k]=1;
 26     int x;
 27     for(int i=0;i<g[k].size();i++){
 28         x=g[k][i];
 29         if(x==fa||vis[x]) continue;
 30         dfs0(x,k);
 31         siz[k]+=siz[x];
 32     }
 33     return;
 34 }
 35 void getsum(int k,int fa){
 36     int x;
 37     if(cnt[color[k]]==1&&fa){
 38         if(colorvis[color[k]]!=timer) colorvis[color[k]]=timer,sumsiz[color[k]]=0;
 39         sumsiz[color[k]]+=siz[k];
 40     }
 41     for(int i=0;i<g[k].size();i++){
 42         x=g[k][i];
 43         if(x==fa||vis[x]) continue;
 44         cnt[color[x]]++;
 45         getsum(x,k);
 46     }
 47     cnt[color[k]]--;
 48 }
 49 void fix(int k,int fa){
 50     int x;
 51     if(isfirst[k]==timer){
 52         sub[k]-=sumsiz[color[k]]-cursumsiz[color[k]];
 53     }
 54     for(int i=0;i<g[k].size();i++){
 55         x=g[k][i];
 56         if(x==fa||vis[x]) continue;
 57         fix(x,k);
 58     }
 59     return;
 60 }
 61 void dfs1(int k,int fa,int r){
 62     int x;
 63     if(cnt[color[k]]==1&&fa){
 64         isfirst[k]=timer;
 65         if(curcolorvis[color[k]]!=timer2) curcolorvis[color[k]]=timer2,cursumsiz[color[k]]=0;
 66         sub[k]+=tot_node-siz[nowchild],sub[r]+=siz[k],sub[nowchild]-=siz[k];
 67         cursumsiz[color[k]]+=siz[k];
 68     }
 69     for(int i=0;i<g[k].size();i++){
 70         x=g[k][i];
 71         if(x==fa||vis[x]) continue;
 72         if(!fa) nowchild=x,timer2++;
 73         cnt[color[x]]++;
 74         dfs1(x,k,r);
 75         if(!fa) fix(x,k);
 76     }
 77     cnt[color[k]]--;
 78     return;
 79 }
 80 void dfs2(int k,int fa){
 81     int x;
 82     sum[k]+=sub[k];
 83     for(int i=0;i<g[k].size();i++){
 84         x=g[k][i];
 85         if(x==fa||vis[x]) continue;
 86         sub[x]+=sub[k];
 87         dfs2(x,k);
 88     }
 89     sub[k]=0;
 90     return;
 91 }
 92 void calc(int k){
 93     int x;
 94     timer++;
 95     dfs0(k,0);
 96     cnt[color[k]]++;
 97     getsum(k,0);
 98     cnt[color[k]]++;
 99     dfs1(k,0,k);
100     for(int i=0;i<g[k].size();i++){
101         x=g[k][i];
102         if(vis[x]) continue;
103         sub[x]+=tot_node-siz[x];
104     }
105     sum[k]+=tot_node;
106     dfs2(k,0);
107     return;
108 }
109 void work(int k){
110     vis[k]=1;
111     int x;
112     calc(k);
113     for(int i=0;i<g[k].size();i++){
114         x=g[k][i];
115         if(vis[x]) continue;
116         root=0,tot_node=siz[x];
117         getroot(x,0);
118         work(root);
119     }
120     return;
121 }
122 int main(){
123     int t1,t2;
124     f[0]=oo;
125     scanf("%d",&n);
126     for(int i=1;i<=n;i++) scanf("%d",&color[i]);
127     for(int i=1;i<n;i++){
128         scanf("%d%d",&t1,&t2);
129         g[t1].push_back(t2);g[t2].push_back(t1);
130     }
131     root=0,tot_node=n;
132     getroot(1,0);
133     work(root);
134     for(int i=1;i<=n;i++)printf("%lld\n",sum[i]);
135     return 0;
136 }

 

posted @ 2018-03-08 13:54  Cupcake  阅读(158)  评论(0编辑  收藏  举报