【FFT+点分治,树重心】CODECHEF PRIMEDST
链接:https://www.codechef.com/problems/PRIMEDST
题意:一棵树上任取两点,问两点间距离为质数的概率
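思路:对于统计树上路径数的题,可以想到用点分治(每次取当前连通块的重心作为分治中心,保证递归深度为 O(log n))。对当前重心,先求出连通块内每个点到重心的距离,把"距离为 d 的点数"看作多项式的系数,用 FFT 求该多项式的平方,就能一次性得到所有经过重心(或以重心为端点)的点对的距离分布;再减去两端点来自同一棵子树的方案数(这些点对之间的真实路径并不经过重心)。最后只统计距离为质数的点对,总复杂度 O(n log^2 n)。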
代码:
#include <cstdio>
#include <cmath>
#include <cstring>
#include <vector>
#include <algorithm>

using namespace std;

typedef long long ll;

const int MAX_N = 100007;       // capacity for vertices / distances / sieve
const int MAX_M = 100007;       // base size used to dimension the FFT buffers
const double PI = acos(-1.0);

// Minimal complex-number type used by the FFT below.
struct Complex {
    double r, i;
    Complex(double _r, double _i) {
        r = _r;
        i = _i;
    }
    Complex operator + (const Complex &c) {
        return Complex(c.r + r, c.i + i);
    }
    Complex operator - (const Complex &c) {
        return Complex(r - c.r, i - c.i);
    }
    Complex operator * (const Complex &c) {
        return Complex(c.r * r - c.i * i, c.r * i + c.i * r);
    }
    Complex operator / (const int &c) {
        return Complex(r / c, i / c);
    }
    Complex() {}
};

namespace FFT {
    // Returns id with its low log2(len) bits reversed (len is a power of two).
    int rev(int id, int len) {
        int ret = 0;
        for (int i = 0; (1 << i) < len; ++i) {
            ret <<= 1;
            if (id & (1 << i)) ret |= 1;
        }
        return ret;
    }

    Complex A[MAX_M << 3];  // scratch buffer for the bit-reversed transform

    // Iterative radix-2 transform of a[0..len), written back into a.
    // len must be a power of two. DFT == 1 uses roots e^{+2*pi*i/m};
    // DFT == -1 uses e^{-2*pi*i/m} and divides by len, so a forward call
    // followed by an inverse call restores the input.
    void FFT(Complex *a, int len, int DFT) {
        for (int i = 0; i < len; ++i) A[rev(i, len)] = a[i];
        for (int s = 1; (1 << s) <= len; ++s) {
            int m = (1 << s);
            Complex wm = Complex(cos(PI * DFT * 2 / m), sin(PI * DFT * 2 / m));
            for (int k = 0; k < len; k += m) {
                Complex w = Complex(1, 0);
                for (int j = 0; j < (m >> 1); j++) {
                    Complex t = w * A[k + j + (m >> 1)];
                    Complex u = A[k + j];
                    A[k + j] = u + t;
                    A[k + j + (m >> 1)] = u - t;
                    w = w * wm;
                }
            }
        }
        if (DFT == -1) for (int i = 0; i < len; ++i) A[i] = A[i] / len;
        for (int i = 0; i < len; i++) a[i] = A[i];
    }
}

// Adjacency-list edge: target vertex and index of the next edge from the same source.
struct Node {
    int v, nxt;
    Node () {
    }
    Node (int _v, int _n) {
        v = _v;
        nxt = _n;
    }
};

int n;
int head[MAX_N], edgecnt;
Node G[MAX_N << 1];
bool p[MAX_N], del[MAX_N];      // p: sieve flags; del: vertex already used as a centroid
int prime[MAX_N], primeCnt;

// Resets the graph and the centroid-removed flags for a new test case.
void Clear() {
    edgecnt = 0;
    memset(head, -1, sizeof head);
    memset(del, 0, sizeof del);
}

// Sieve of Eratosthenes: fills prime[0..primeCnt) with all primes below MAX_N.
void init() {
    primeCnt = 0;
    memset(p, 1, sizeof p);
    for (int i = 2; i < MAX_N; ++i) {
        if (p[i]) {
            prime[primeCnt++] = i;
            for (int j = 2 * i; j < MAX_N; j += i)
                p[j] = false;
        }
    }
}

// Adds the directed edge u -> v (call twice for an undirected edge).
void add(int u, int v) {
    G[edgecnt] = Node(v, head[u]);
    head[u] = edgecnt++;
}

int son[MAX_N], opt[MAX_N];     // subtree size / largest child subtree size
vector<int> alln;               // vertices of the component currently being examined

// Collects the component of u (ignoring deleted vertices) into alln and
// computes subtree sizes rooted at u.
void dfs(int u, int fa) {
    alln.push_back(u);
    son[u] = 1; opt[u] = 0;
    for (int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if (del[v] || v == fa) continue;
        dfs(v, u);
        son[u] += son[v];
        opt[u] = max(opt[u], son[v]);
    }
}

// Returns the centroid of u's component: the vertex minimizing the size of
// the largest piece left after removing it.
int getCenter(int u) {
    alln.clear();
    dfs(u, -1);
    int mx = 0, ans = -1;
    int sz = alln.size();
    for (int i = 0; i < sz; ++i) {
        int v = alln[i];
        if (ans == -1) ans = v, mx = max(opt[v], sz - son[v]);
        else {
            if (max(opt[v], sz - son[v]) < mx) {
                mx = max(opt[v], sz - son[v]);
                ans = v;
            }
        }
    }
    return ans;
}

int tot, D[MAX_N];
// Appends to D[0..tot) the distance w of u and of every non-deleted vertex
// below it (excluding the fa direction), each edge adding 1.
void getDist(int u, int fa, int w) {
    D[tot++] = w;
    for (int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if (del[v] || v == fa) continue;
        getDist(v, u, w + 1);
    }
}

int cnt[MAX_N];
ll res[MAX_M << 3];
Complex A[MAX_M << 3];

// Counts unordered index pairs {x, y}, x != y, of D[0..tot) whose sum
// D[x] + D[y] is prime. The histogram of D is squared via FFT, giving the
// number of ordered pairs for every possible sum.
ll calc() {
    if (tot == 0) return 0;     // max_element on an empty range must not be dereferenced
    int up = *max_element(D, D + tot);
    // Only cnt[0..up] is read below, so clearing that prefix suffices.
    // calc() runs once per centroid and once per centroid child; clearing
    // all MAX_N entries each time made the whole run O(n * MAX_N).
    memset(cnt, 0, sizeof(int) * (up + 1));
    for (int i = 0; i < tot; ++i) ++cnt[D[i]];
    for (int i = 0; i <= up; ++i) A[i] = Complex(cnt[i], 0);
    int len = 1;
    while (len <= up) len <<= 1;
    len <<= 1;                  // room for sums up to 2 * up without wrap-around
    for (int i = up + 1; i < len; ++i) A[i] = Complex(0, 0);
    FFT::FFT(A, len, 1);
    for (int i = 0; i < len; ++i) A[i] = A[i] * A[i];
    FFT::FFT(A, len, -1);
    for (int i = 0; i < len; ++i) res[i] = (ll)(A[i].r + 0.5);
    for (int i = 0; i < tot; ++i) --res[D[i] + D[i]];   // drop pairs (x, x)
    for (int i = 0; i < primeCnt && prime[i] < len; ++i) res[prime[i]] /= 2;  // ordered -> unordered
    ll ans = 0;
    for (int i = 0; i < primeCnt && prime[i] < len; ++i) ans += res[prime[i]];
    return ans;
}

ll ans;  // running count of unordered vertex pairs at prime distance

// Centroid decomposition step: counts prime-distance pairs whose path passes
// through the centroid of u's component, then recurses into the remaining pieces.
void solve(int u) {
    u = getCenter(u);
    tot = 0;
    getDist(u, -1, 0);
    ans += calc();              // every pair in the component, joined through the centroid
    for (int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if (del[v]) continue;
        tot = 0;
        getDist(v, u, 1);
        ans -= calc();          // pairs inside one child subtree do not really pass through u
    }
    del[u] = true;
    for (int i = head[u]; ~i; i = G[i].nxt) {
        int v = G[i].v;
        if (del[v]) continue;
        solve(v);
    }
}

// Reads trees until EOF and prints, for each, the probability that a
// uniformly chosen unordered pair of distinct vertices is at prime distance.
int main() {
    init();
    while (1 == scanf("%d", &n)) {
        Clear();
        for (int i = 1; i < n; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v), add(v, u);
        }
        ans = 0;
        solve(1);
        // There are n * (n - 1) / 2 unordered pairs; with fewer than two
        // vertices there are none, and the formula would divide by zero.
        double an = n < 2 ? 0.0 : 1. * ans * 2 / n / (n - 1);
        printf("%.8f\n", an);
    }
    return 0;
}