Codechef Prime Distance On Tree
FFT第四题!
暑假的时候只会点分治，路径合并用的是暴力合并……勉强水过去了。
其实两条路径长度的合并就是一次卷积：每次统计完当前分治中心下所有路径的长度分布后，对这个分布做一次自卷积即可。
刚开始卷积时固定使用整个值域，结果 TLE 了。后来就不偷懒了，每次取当前最大深度的两倍来确定卷积的值域（FFT 长度）。
// CodeChef PRIMEDST ("Prime Distance On Tree").
// Reads a tree with n nodes (n, then n-1 edges "u v", nodes 1-indexed) and
// prints the fraction of ordered pairs of distinct nodes whose distance is prime.
//
// Method: centroid decomposition. For each centroid c, the depth histogram of
// c's component is squared as a polynomial via FFT, which counts ordered node
// pairs by summed depth. Pairs lying in the same child subtree of c (whose true
// path does not pass through c) are removed by subtracting the squared
// histogram of each child subtree, with depths measured from c (offset 1).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

const double pi = std::acos(-1.0);

// Minimal complex number used by the FFT below.
struct Complex {
    double r, i;
    void clear() { r = i = 0.0; }
    Complex(double r = 0, double i = 0) : r(r), i(i) {}
    Complex operator+(const Complex &p) const { return Complex(r + p.r, i + p.i); }
    Complex operator-(const Complex &p) const { return Complex(r - p.r, i - p.i); }
    Complex operator*(const Complex &p) const {
        return Complex(r * p.r - i * p.i, r * p.i + i * p.r);
    }
};

// In-place iterative radix-2 FFT.
//   a  : n coefficients; n must be a power of two
//   pd : +1 = forward transform, -1 = inverse transform (also divides by n)
//   r  : bit-reversal permutation of [0, n)
void FFT(Complex *a, int n, int pd, int *r) {
    for (int i = 0; i < n; i++)
        if (i < r[i]) std::swap(a[i], a[r[i]]);
    for (int mid = 1; mid < n; mid <<= 1) {
        // Unit root for butterflies of span 2*mid; the sign of pd picks the direction.
        Complex wn(std::cos(pi / mid), pd * std::sin(pi / mid));
        for (int len = mid << 1, j = 0; j < n; j += len) {
            Complex w(1.0, 0.0);
            for (int k = 0; k < mid; k++, w = w * wn) {
                Complex u = a[k + j], v = w * a[k + j + mid];
                a[k + j] = u + v;
                a[k + j + mid] = u - v;
            }
        }
    }
    if (pd < 0)
        for (int i = 0; i < n; i++) a[i] = Complex(a[i].r / n, a[i].i / n);
}

using ll = long long;

const int N = 2e5 + 7;
int n, sz[N], maxsz[N], root, totsz;  // centroid search state
std::vector<int> vec[N];              // adjacency lists, nodes 1..n
int prime[N], prin;                   // primes below N in prime[1..prin]
bool vis[N];                          // node already used as a centroid
bool is[N];                           // sieve: true if composite
ll cnt[N];                            // cnt[d]: ordered pairs at distance d
ll ccnt[N];                           // scratch depth histogram used by cal()
int dis[N], r[N];                     // depths / bit-reversal table
Complex A[N];                         // FFT work buffer
int limit, l;                         // FFT length and log2 of it

// Linear sieve: fills prime[1..prin] with all primes below N.
void init() {
    for (int i = 2; i < N; i++) {
        if (!is[i]) prime[++prin] = i;
        for (int j = 1; j <= prin && i * prime[j] < N; j++) {
            is[i * prime[j]] = 1;
            if (i % prime[j] == 0) break;  // each composite is marked once, by its smallest prime
        }
    }
}

// a = max(a, b); returns true if a changed.
inline bool chkmax(int &a, int b) { return a < b ? a = b, 1 : 0; }

// Computes subtree sizes of the current component (of total size totsz) and
// updates the global root to the node with the smallest largest remaining part.
// Node 0 is a sentinel with maxsz[0] = n.
void getroot(int u, int fa) {
    sz[u] = 1;
    maxsz[u] = 0;
    for (int v : vec[u]) {
        if (v == fa || vis[v]) continue;
        getroot(v, u);
        sz[u] += sz[v];
        chkmax(maxsz[u], sz[v]);
    }
    chkmax(maxsz[u], totsz - sz[u]);  // the part "above" u in this DFS
    if (maxsz[u] < maxsz[root]) root = u;
}

int f[N], tto, val;  // collected depths f[1..tto] and their maximum

// Appends dis[] of every node reachable from u (not crossing removed centroids) to f.
void getdis(int u, int fa) {
    f[++tto] = dis[u];
    val = std::max(val, f[tto]);
    for (int v : vec[u]) {
        if (vis[v] || v == fa) continue;
        dis[v] = dis[u] + 1;
        getdis(v, u);
    }
}

// Adds opt * (number of ordered node pairs (x, y) with depth(x) + depth(y) == k)
// to cnt[k] for every k >= 1, where depths are measured from u starting at d.
void cal(int u, int d, int opt) {
    tto = 0;
    dis[u] = d;
    val = 0;
    getdis(u, 0);
    for (int i = 1; i <= tto; i++) ccnt[f[i]]++;

    // Smallest power of two strictly greater than the largest pair sum 2*val.
    limit = 1, l = 0;
    while (limit <= 2 * val) limit <<= 1, l++;

    // Bit-reversal table. "(i & 1) ? limit >> 1 : 0" equals "(i & 1) << (l - 1)"
    // for l >= 1, but stays well-defined when limit == 1 (a shift by -1 is UB).
    r[0] = 0;
    for (int i = 1; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (limit >> 1) : 0);

    for (int i = 0; i < limit; i++) A[i] = Complex((double)ccnt[i], 0.0);
    FFT(A, limit, 1, r);
    for (int i = 0; i < limit; i++) A[i] = A[i] * A[i];
    FFT(A, limit, -1, r);
    // Round to the nearest integer count; k = 0 only holds the (c, c) self-pair.
    for (int i = 1; i < limit; i++) cnt[i] += opt * (ll)(A[i].r + 0.5);

    for (int i = 1; i <= tto; i++) ccnt[f[i]]--;  // restore ccnt to all zeros
}

// Number of nodes reachable from u without passing fa or a removed centroid.
int compsize(int u, int fa) {
    int s = 1;
    for (int v : vec[u])
        if (v != fa && !vis[v]) s += compsize(v, u);
    return s;
}

// Processes centroid u, then recurses into each remaining child component.
void solve(int u) {
    vis[u] = 1;
    cal(u, 0, 1);
    for (int v : vec[u]) {
        if (vis[v]) continue;
        cal(v, 1, -1);  // drop pairs whose path would not actually pass through u
        // sz[v] from the previous getroot may be stale (if v was u's DFS parent),
        // so recount the child component's size exactly.
        totsz = compsize(v, u);
        root = 0;
        getroot(v, 0);
        solve(root);
    }
}

int main() {
    init();
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    maxsz[root = 0] = n;
    totsz = n;
    getroot(1, 0);
    solve(root);

    ll ans = 0;
    for (int i = 1; i <= prin; i++) ans += cnt[prime[i]];
    // cnt[] counts ordered pairs of distinct nodes, so divide by n * (n - 1).
    ll sum = 1LL * n * (n - 1);
    printf("%.7f\n", sum > 0 ? 1.0 * ans / sum : 0.0);  // n <= 1: no pairs, avoid 0/0
    return 0;
}