【FFT+点分治,树重心】CODECHEF PRIMEDST

通道:https://www.codechef.com/problems/PRIMEDST

题意:一棵树上任取两点,问两点间距离为质数的概率

思路:对于统计树上路径数的问题,可以想到用树分治(点分治)。每次取当前连通块的重心,只考虑经过该点的路径:统计块内各点到重心的距离,把距离的出现次数看作多项式的系数,用 FFT 将其自乘,快速求出每种距离和对应的方案数,然后减去两个端点来自同一棵子树的方案数。

代码:

  1 #include <cstdio>
  2 #include <cmath>
  3 #include <cstring>
  4 #include <vector>
  5 #include <algorithm>
  6 
  7 using namespace std;
  8 
  9 typedef long long ll;
 10 
 11 const int MAX_N = 100007; 
 12 const int MAX_M = 100007;
 13 const double PI = acos(-1.0);
 14 
 15 struct Complex {
 16     double r, i;
 17     Complex(double _r, double _i) {
 18         r = _r;
 19         i = _i;
 20     }
 21     Complex operator + (const Complex &c) {
 22         return Complex(c.r + r, c.i + i);
 23     }
 24     Complex operator - (const Complex &c) {
 25         return Complex(r - c.r, i - c.i);
 26     }
 27     Complex operator * (const Complex &c) {
 28         return Complex(c.r * r - c.i * i, c.r * i + c.i * r);
 29     }
 30     Complex operator / (const int &c) {
 31         return Complex(r / c, i / c);
 32     }
 33     Complex(){}
 34 };
 35 namespace FFT {
 36     int rev(int id, int len) {
 37         int ret = 0;
 38         for(int i = 0; (1 << i) < len; ++i) {
 39             ret <<= 1;
 40             if(id & (1 << i)) ret |= 1;
 41         }
 42         return ret;
 43     }
 44     Complex A[MAX_M << 3];
 45     void FFT(Complex *a, int len, int DFT) {
 46         for(int i = 0; i < len; ++i) A[rev(i, len)] = a[i];
 47         for(int s = 1; (1 << s) <= len; ++s) {
 48             int m = (1 << s);
 49             Complex wm = Complex(cos(PI * DFT * 2 / m), sin(PI * DFT * 2 / m));
 50             for(int k = 0; k < len; k += m) {
 51                 Complex w = Complex(1, 0);
 52                 for(int j = 0; j < (m >> 1); j++) {
 53                     Complex t = w * A[k + j + (m >> 1)];
 54                     Complex u = A[k + j];
 55                     A[k + j] = u + t;
 56                     A[k + j + (m >> 1)] = u - t;
 57                     w = w * wm;
 58                 }
 59             }
 60         }
 61         if(DFT == -1) for(int i = 0; i < len; ++i) A[i] = A[i] / len;
 62         for(int i = 0; i < len; i++) a[i] = A[i];
 63     }
 64 };
 65 
 66 struct Node {
 67     int v, nxt;
 68     Node () {
 69         
 70     }
 71     Node (int _v, int _n) {
 72         v = _v;
 73         nxt = _n;
 74     }
 75 };
 76 
 77 int n;
 78 int head[MAX_N], edgecnt;
 79 Node G[MAX_N << 1];
 80 bool p[MAX_N], del[MAX_N];
 81 int prime[MAX_N], primeCnt;
 82 
 83 void Clear() {
 84     edgecnt = 0;
 85     memset(head, -1, sizeof head);
 86     memset(del, 0, sizeof del);
 87 }
 88 
 89 void init() {
 90     primeCnt = 0;
 91     memset(p, 1, sizeof p);
 92     for (int i = 2; i < MAX_N; ++i) {
 93         if (p[i]) {
 94             prime[primeCnt++] = i;
 95             for (int j = 2 * i; j < MAX_N; j += i)
 96                 p[j] = false;
 97         }
 98     }
 99 }
100 
101 void add(int u, int v){
102     G[edgecnt] = Node(v, head[u]);
103     head[u] = edgecnt++;
104 }
105 
106 int son[MAX_N], opt[MAX_N];
107 vector<int> alln;
108 
109 void dfs(int u,int fa) {
110     alln.push_back(u);
111     son[u] = 1; opt[u] = 0;
112     for(int i = head[u]; ~i; i = G[i].nxt) {
113         int v = G[i].v;
114         if(del[v] || v == fa) continue;
115         dfs(v, u);
116         son[u] += son[v];
117         opt[u] = max(opt[u], son[v]);
118     }
119 }
120 
121 int getCenter(int u) {
122     alln.clear();
123     dfs(u, -1);
124     int mx = 0, ans = -1;
125     int sz = alln.size();
126     for(int i = 0; i < sz; ++i)  {
127         int v = alln[i];
128         if(ans == -1) ans = v, mx = max(opt[v], sz - son[v]);
129         else  {
130             if(max(opt[v], sz - son[v]) < mx) {
131                 mx = max(opt[v], sz - son[v]);
132                 ans = v;
133             }
134         }
135     }
136     return ans;
137 }
138 
139 int tot, D[MAX_N];
140 void getDist(int u, int fa, int w) {
141     D[tot++] = w;
142     for(int i = head[u]; ~i; i = G[i].nxt) {
143         int v = G[i].v;
144         if(del[v] || v == fa) continue;
145         getDist(v, u, w + 1);
146     } 
147 }
148 
149 int cnt[MAX_N];
150 ll res[MAX_M << 3];
151 Complex A[MAX_M << 3];
152 
153 ll calc() {
154     int up = *max_element(D, D + tot);
155     memset(cnt, 0, sizeof cnt);
156     for (int i = 0; i < tot; ++i) ++cnt[D[i]];
157     for (int i = 0; i <= up; ++i) A[i] = Complex(cnt[i], 0);
158     int len = 1;
159     while (len <= up) len <<= 1; len <<= 1;
160     for (int i = up + 1; i < len; ++i) A[i] = Complex(0, 0);
161     FFT::FFT(A, len, 1);
162     for (int i = 0; i < len; ++i) A[i] = A[i] * A[i];
163     FFT::FFT(A, len, -1);
164     for (int i = 0; i < len; ++i) res[i] = (ll)(A[i].r + 0.5);
165     for (int i = 0; i < tot; ++i) --res[D[i] + D[i]];
166     for (int i = 0; i < primeCnt && prime[i] < len; ++i) res[prime[i]] /= 2; 
167     ll ans = 0;
168     for (int i = 0; i < primeCnt && prime[i] < len; ++i) ans += res[prime[i]];
169     return ans;
170 }
171 
172 ll ans;
173 void solve(int u) {
174     u = getCenter(u);
175     tot = 0; getDist(u, -1, 0);
176     ans += calc();
177     for(int i = head[u]; ~i; i = G[i].nxt) {
178         int v = G[i].v;
179         if(del[v]) continue;
180         tot = 0;
181         getDist(v, u, 1);
182         ans -= calc();
183     }
184     del[u] = true;
185     for(int i = head[u]; ~i; i =G[i].nxt) {
186         int v = G[i].v;
187         if(del[v]) continue;
188         solve(v);
189     }
190 }
191 
192 int main() {
193     init();
194     while (1 == scanf("%d", &n)) {
195         Clear();
196         for (int i = 1; i < n; ++i) {
197             int u, v;
198             scanf("%d%d", &u, &v);
199             add(u, v), add(v, u);
200         }
201         ans = 0;
202         solve(1);
203         double an = 1. * ans * 2 / n / (n - 1);
204         printf("%.8f\n", an);
205     }
206     return 0;
207 }
View Code

 

posted @ 2015-07-24 16:08  mithrilhan  阅读(224)  评论(0编辑  收藏  举报