【Foreign】采蘑菇 [点分治]
采蘑菇
Time Limit: 20 Sec Memory Limit: 256 MBDescription
Input
Output
Sample Input
5
1 2 3 2 3
1 2
1 3
2 4
2 5
1 2 3 2 3
1 2
1 3
2 4
2 5
Sample Output
10
9
12
9
11
9
12
9
11
HINT
Main idea
询问从以每个点为起始点时,各条路径上的颜色种类的和。
Solution
我们看到题目,立马想到了O(n^2)的做法,然后从这个做法研究一下本质,我们确定了可以以点分治作为框架。
我们先用点分治来确定一个center(重心)。然后计算跟这个center有关的路径。设现在要统计的是经过center,对x提供贡献的路径。
我们先记录一个记录Sum[x]表示1~i-1子树中 颜色x 第一次出现的位置的那个点 的子树和,然后我们就利用这个Sum来解题。
我们显然可以分两种情况来讨论:
(1)统计center->x出现颜色的贡献:
显然,这时候,对于center->x这一段,直接像O(n^2)做法那样记录一个color表示到目前为止出现的颜色个数,然后加一下即可。再记录一个record表示当前可有的贡献和,一旦出现过一个颜色,那么这个颜色在1~i-1子树上出现第一次以下的点,对于x就不再提供贡献了,record减去Sum[这个颜色],然后这样深搜往下计算即可。
(2)统计center->x没出现过的颜色的贡献:
显然,对于center->x上没出现过的颜色,直接往下深搜,一开始为record为(All - Sum[center]),一旦出现了一个颜色,record则减去这个Sum。同样表示不再提供贡献即可。
我们这样做就可以求出每个子树前缀对于其的贡献了,倒着再做一边即可求出全部的贡献。统计x的时候,顺便统计一下center。可以满足效率,成功AC这道题。
Code
1 #include<iostream>
2 #include<algorithm>
3 #include<cstdio>
4 #include<cstring>
5 #include<cstdlib>
6 #include<cmath>
7 using namespace std;
8
9 const int ONE = 600005;
10 const int INF = 214783640;
11 const int MOD = 1e9+7;
12
13 int n,x,y;
14 int Val[ONE];
15 int next[ONE],first[ONE],go[ONE],tot;
16 int vis[ONE];
17 int Ans[ONE],Sum[ONE];
18 int All;
19
20
21 int get()
22 {
23 int res,Q=1; char c;
24 while( (c=getchar())<48 || c>57)
25 if(c=='-')Q=-1;
26 if(Q) res=c-48;
27 while((c=getchar())>=48 && c<=57)
28 res=res*10+c-48;
29 return res*Q;
30 }
31
32 void Add(int u,int v)
33 {
34 next[++tot]=first[u]; first[u]=tot; go[tot]=v;
35 next[++tot]=first[v]; first[v]=tot; go[tot]=u;
36 }
37
38 namespace Point
39 {
40 int center;
41 int Stack[ONE],top;
42 int total,Max,center_vis[ONE];
43 int num,V[ONE];
44
45 struct power
46 {
47 int size,maxx;
48 }S[ONE];
49
50 void Getsize(int u,int father)
51 {
52 S[u].size=1;
53 S[u].maxx=0;
54 for(int e=first[u];e;e=next[e])
55 {
56 int v=go[e];
57 if(v==father || center_vis[v]) continue;
58 Getsize(v,u);
59 S[u].size += S[v].size;
60 S[u].maxx = max(S[u].maxx,S[v].size);
61 }
62 }
63
64 void Getcenter(int u,int father,int total)
65 {
66 S[u].maxx = max(S[u].maxx,total-S[u].size);
67 if(S[u].maxx < Max)
68 {
69 Max = S[u].maxx;
70 center = u;
71 }
72
73 for(int e=first[u];e;e=next[e])
74 {
75 int v=go[e];
76 if(v==father || center_vis[v]) continue;
77 Getcenter(v,u,total);
78 }
79 }
80
81 void Ad_sum(int u,int father)
82 {
83 if(!vis[Val[u]])
84 {
85 Stack[++top] = Val[u];
86 All += S[u].size; Sum[Val[u]] += S[u].size;
87 }
88 vis[Val[u]]++;
89 for(int e=first[u];e;e=next[e])
90 {
91 int v=go[e];
92 if(v==father || center_vis[v]) continue;
93 Ad_sum(v,u);
94 }
95 vis[Val[u]]--;
96 }
97
98 void Calc_in(int u,int father,int center,int Size,int f_time,int record)
99 {
100 if(!vis[Val[u]]) f_time++, record += Size, record -= Sum[Val[u]];
101 Ans[u] += record; Ans[center]+=f_time;
102 Ans[u] += f_time; vis[Val[u]] ++;
103 for(int e=first[u];e;e=next[e])
104 {
105 int v=go[e];
106 if(v==father || center_vis[v]) continue;
107 Calc_in(v,u,center,Size,f_time,record);
108 }
109 vis[Val[u]] --;
110 }
111
112 void Calc_not(int u,int father,int record)
113 {
114 if(!vis[Val[u]]) record -= Sum[ Val[u] ];
115 Ans[u] += record; vis[Val[u]] ++;
116 for(int e=first[u];e;e=next[e])
117 {
118 int v=go[e];
119 if(v==father || center_vis[v]) continue;
120 Calc_not(v,u,record);
121 }
122 vis[Val[u]] --;
123 }
124
125 void Dfs(int u)
126 {
127 Max = n;
128 Getsize(u,0);
129 Getcenter(u,0,S[u].size);
130 Getsize(center,0);
131 center_vis[center] = 1;
132
133 int num=0; for(int e=first[center];e;e=next[e]) if(!center_vis[go[e]]) V[++num]=go[e];
134
135 for(int i=1;i<=num;i++)
136 {
137 int v=V[i];
138 int Size = S[center].size - S[v].size - 1;
139 vis[Val[center]] = 1;
140 Calc_in(v,center,center, Size,1,All - Sum[Val[center]] + Size);
141 vis[Val[center]] = 0;
142 Ad_sum(v,center);
143 }
144 while(top) Sum[Stack[top--]]=0; All=0;
145
146 for(int i=num;i>=1;i--)
147 {
148 int v=V[i];
149 vis[Val[center]] = 1;
150 Calc_not(v,center, All-Sum[Val[center]]);
151 vis[Val[center]] = 0;
152 Ad_sum(v,center);
153 }
154
155 while(top) Sum[Stack[top--]]=0; All=0;
156 for(int e=first[center];e;e=next[e])
157 {
158 int v=go[e];
159 if(center_vis[v]) continue;
160 Dfs(v);
161 }
162 }
163
164 }
165
166 int main()
167 {
168 n=get();
169 for(int i=1;i<=n;i++) Val[i]=get();
170
171 for(int i=1;i< n;i++)
172 {
173 x=get(); y=get();
174 Add(x,y);
175 }
176
177 Point:: Dfs(1);
178 for(int i=1;i<=n;i++)
179 printf("%d\n",Ans[i]+1);
180 }