CodeChef Prime Distance On Tree

题面

Problem description.

You are given a tree. If we select 2 distinct nodes uniformly at random, what's the probability that the distance between these 2 nodes is a prime number?

Input

The first line contains a number N: the number of nodes in this tree.
The following N-1 lines contain pairs a[i] and b[i], which means there is an edge with length 1 between a[i] and b[i].

Output

Output a real number denoting the required probability.
Your answer is accepted if its absolute difference from the reference answer is at most \(10^{-6}\).

Constraints

\(2 \le N \le 50000\)

The input is guaranteed to be a tree.

Example

Input:
5
1 2
2 3
3 4
4 5

Output:
0.5

Explanation

There are C(5, 2) = 10 possible pairs, and the following 5 of them have a prime distance:

1-3, 2-4, 3-5: 2

1-4, 2-5: 3

Note that 1 is not a prime number.

题意

给定一棵树,边权均为 1,从中随机选取两个不同的点,求它们之间的路径长度为素数的概率。

题解

同样也是点分治:以分治中心为根,getdis 统计深度为 x 的点数 \(d_x\),则两端深度之和为 i 的(有序)点对数为

\[c_i = \sum_{j=0}^{i} d_j \, d_{i-j}\]

这是一个卷积,用 FFT 求出后把 i 为素数的 \(c_i\) 累加即可;其中两端落在同一棵子树内的点对并不经过分治中心,需要对每棵子树以深度 1 为起点再算一次并减去(容斥)。

注意每次计算前都要把记录最大深度(决定 FFT 长度)的 len1 清零,否则 FFT 长度会沿用历史最大值,分治时就会 TLE。

#include <bits/stdc++.h>

using namespace std;
// Upper bound for vertex-indexed arrays and for the prime sieve in main().
const int N = 100050;
//fft begin
// pi, used to build the FFT roots of unity.
const double pi = acos(-1.0);
// 64-bit integer type used for pair counts.
using ll = long long;
// Minimal complex number used by the FFT: r is the real part, i the imaginary part.
struct cp {
    double r, i;
    cp(double r = 0, double i = 0): r(r), i(i) {}
    // The operators do not modify *this, so they are const-qualified and also
    // work on const operands.
    cp operator + (const cp &b) const {
        return cp(r + b.r, i + b.i);
    }
    cp operator - (const cp &b) const {
        return cp(r - b.r, i - b.i);
    }
    // (r + i*j)(b.r + b.i*j) = (r*b.r - i*b.i) + (r*b.i + i*b.r)*j
    cp operator * (const cp &b) const {
        return cp(r * b.r - i * b.i, r * b.i + i * b.r);
    }
};
// Permutes a[0..len) into bit-reversed index order, the layout the in-place
// butterfly loop in fft() expects. len is assumed to be a power of two
// (calc() only ever passes powers of two).
void change(cp a[], int len) {
    int rev = 0;  // bit-reversal of idx, updated incrementally
    for (int idx = 1; idx < len; ++idx) {
        // "Add one" to rev in mirrored bit order: clear the run of set bits
        // starting from the top, then set the first clear bit found.
        int bit = len >> 1;
        while (rev & bit) {
            rev ^= bit;
            bit >>= 1;
        }
        rev |= bit;
        // Swap each pair only once (from its smaller index).
        if (idx < rev) swap(a[idx], a[rev]);
    }
}
// In-place iterative radix-2 FFT over a[0..len).
//   op =  1 : forward transform, using roots exp(-2*pi*i/h).
//   op = -1 : inverse transform, including the 1/len normalisation.
// len must be a power of two.
void fft(cp a[], int len, int op) {
    // Bit-reversal permutation so the butterflies can run in place.
    change(a, len);
    for (int h = 2; h <= len; h <<= 1) {
        // Principal h-th root of unity for this stage; its sign depends on op.
        cp wn(cos(-op * 2 * pi / h), sin(-op * 2 * pi / h));
        for (int j = 0; j < len; j += h) {
            cp w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                cp u = a[k];
                cp t = w * a[k + h / 2];
                a[k] = u + t;
                a[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (op == -1) {
        // Normalise both components so the result is the true inverse DFT,
        // not just its real part.
        for (int i = 0; i < len; i++) {
            a[i].r /= len;
            a[i].i /= len;
        }
    }
}
//fft end


// prime[1..prime[0]] holds the primes below N (prime[0] is their count);
// notprime[x] is set for composites. Both are filled by the sieve in main().
int prime[N], notprime[N];
// Adjacency lists of the tree (main() starts the decomposition at vertex 1).
vector<int> G[N];
// Component size that getroot() uses for the "rest of the component" part.
int S;
// Smallest max-part size seen so far by getroot(); reset to 1e9 before each search.
int maxx;
// Centroid selected by getroot().
int root;
// vis[u] = 1 once u has been used as a divide-and-conquer center.
int vis[N];
// msze[u]: largest part left if u were removed from the current component.
int msze[N];
// sze[u]: subtree size of u in the most recent getroot() traversal.
int sze[N];
// Searches the component containing u (vertices with vis == 0) for its
// centroid. Fills sze[]/msze[] for the visited vertices and stores the vertex
// whose largest remaining part is smallest into root (with that size in maxx).
void getroot(int u, int f) {
    sze[u] = 1;
    msze[u] = 0;
    for (int v : G[u]) {
        if (v != f && !vis[v]) {
            getroot(v, u);
            sze[u] += sze[v];
            if (sze[v] > msze[u]) msze[u] = sze[v];
        }
    }
    // The part "above" u in this traversal.
    const int rest = S - sze[u];
    if (rest > msze[u]) msze[u] = rest;
    if (msze[u] < maxx) {
        maxx = msze[u];
        root = u;
    }
}
// dis[x]: number of vertices found at distance x by getdis() during calc().
int dis[N / 2];
// NOTE(review): unused — main() declares its own local cnt that shadows this.
int cnt;
// Largest distance recorded by getdis(); calc() sizes the FFT from it.
int len1 = 0;
// Walks the not-yet-removed part of the tree from u (coming from f), where u
// itself is at distance w. Every reached vertex increments dis[its distance],
// and len1 is raised to the largest distance seen.
void getdis(int u, int f, int w) {
    ++dis[w];
    if (w > len1) len1 = w;
    for (int v : G[u]) {
        if (v != f && !vis[v]) getdis(v, u, w + 1);
    }
}
// Scratch buffer for calc()'s FFT. The padded length stays below
// 4 * (len1 + 1) <= 4 * N.
cp a[N * 4];

// Counts ordered vertex pairs (x, y) reachable from u whose depths sum to a
// prime, with u itself at depth w:
//   calc(u, 0) - pairs whose combined distance through center u is prime;
//   calc(v, 1) - the same pairs restricted to neighbour v's part, measured
//                from the center; dfs() subtracts these because such paths
//                do not really pass through u.
// The depth histogram is built in dis[] and zeroed again here, so dis[] is
// all-zero between calls (it starts zero as a global).
ll calc(int u, int w) {
    len1 = 0;
    getdis(u, 0, w);
    int len = 1;
    while (len < (len1 + 1) * 2) len <<= 1;
    // Copy the histogram into the FFT buffer and reset only the touched
    // prefix of dis[]. Clearing the whole array would cost O(N) on each of the
    // O(n) calls.
    for (int i = 0; i <= len1; i++) {
        a[i] = cp(dis[i], 0);
        dis[i] = 0;
    }
    for (int i = len1 + 1; i < len; i++) {
        a[i] = cp(0, 0);
    }
    fft(a, len, 1);
    // Pointwise square = self-convolution of the depth histogram.
    for (int i = 0; i < len; i++) {
        a[i] = a[i] * a[i];
    }
    fft(a, len, -1);
    ll res = 0;
    // Depth sums cannot exceed 2 * len1 (always < len), so larger primes add nothing.
    for (int i = 1; i <= prime[0] && prime[i] <= 2 * len1; i++) {
        res += (ll)(a[prime[i]].r + 0.5);  // round the floating-point count
    }
    return res;
}
// Number of ordered pairs (x, y), x != y, whose tree distance is prime.
ll ans;

// Centroid decomposition step. u is the centroid of its current component,
// whose size is the value of S on entry. Adds the pairs whose path passes
// through u, then recurses into every remaining part.
void dfs(int u) {
    vis[u] = 1;
    const int comp = S;  // size of u's component; S is overwritten below
    ans += calc(u, 0);
    for (int i = 0; i < G[u].size(); i++) {
        int v = G[u][i];
        if (vis[v]) continue;
        // Remove pairs inside v's part, which do not actually pass through u.
        ans -= calc(v, 1);
        // sze[] still holds the values from the getroot() pass that found u,
        // and that pass may have been rooted elsewhere. If v was u's parent
        // there, sze[v] > sze[u] and v's part really has comp - sze[u] vertices.
        S = sze[v] < sze[u] ? sze[v] : comp - sze[u];
        root = 0;
        maxx = 1e9;
        getroot(v, u);
        dfs(root);
    }
}
// Linear (Euler) sieve: stores every prime below N in prime[1..prime[0]] and
// marks composites in notprime[].
static void build_primes() {
    int count = 0;
    for (int i = 2; i < N; i++) {
        if (!notprime[i]) prime[++count] = i;
        for (int j = 1; j <= count && i * prime[j] < N; j++) {
            notprime[i * prime[j]] = 1;
            // Stop here so each composite is marked only by its smallest prime factor.
            if (i % prime[j] == 0) break;
        }
    }
    prime[0] = count;
}

// Reads the tree, counts ordered pairs at prime distance with centroid
// decomposition, and prints that count divided by the number of ordered
// pairs n*(n-1). This equals the probability for an unordered random pair.
int main() {
    build_primes();
    int n;
    if (scanf("%d", &n) != 1) return 1;
    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        // Reject labels that would index outside the adjacency array.
        if (u < 1 || u >= N || v < 1 || v >= N) return 1;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    if (n < 2) {
        // No pair of distinct vertices exists; avoid printing 0/0 (NaN).
        printf("%.9f\n", 0.0);
        return 0;
    }
    S = n;
    maxx = 1e9;
    root = 0;
    getroot(1, 0);
    dfs(root);
    ll tmp = (ll)n * (ll)(n - 1);
    printf("%.9f\n", (double)ans / tmp);
    return 0;
}
posted @ 2020-01-17 17:44  Artoriax  阅读(161)  评论(0)  编辑  收藏  举报