CodeChef Prime Distance On Tree
题面
Problem description.
You are given a tree. If we select 2 distinct nodes uniformly at random, what's the probability that the distance between these 2 nodes is a prime number?
Input
The first line contains a number N: the number of nodes in this tree.
The following N-1 lines contain pairs a[i] and b[i], which means there is an edge with length 1 between a[i] and b[i].
Output
Output a real number denoting the required probability.
Your answer is accepted if its absolute difference from the reference answer is at most 10^-6.
Constraints
2 ≤ N ≤ 50,000
The input is guaranteed to be a tree.
Example
Input:
5
1 2
2 3
3 4
4 5
Output:
0.5
Explanation
There are C(5, 2) = 10 possible pairs, and 5 of them are at a prime distance:
1-3, 2-4, 3-5: 2
1-4, 2-5: 3
Note that 1 is not a prime number.
题意
给定一棵树,所有边的边权均为1,等概率随机选取两个不同的点,求它们之间路径长度为素数的概率
题解
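同样也是点分治,getdis维护(以当前分治中心为起点)长度为x的路径数量 $\mathrm{dis}_x$,我们要求的即是卷积
$$c_k=\sum_{i+j=k}\mathrm{dis}_i\cdot\mathrm{dis}_j$$
(用FFT计算),然后我们把下标k为素数的 $c_k$ 提出来求和就可以了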
注意记录最大深度(决定fft长度)的len1在每次计算前也要清零,否则fft长度只增不减,分治就会tle了
#include <bits/stdc++.h>
using namespace std;
// Shared array bound: node-indexed tables are sized by it, and the prime
// sieve in main() runs over [2, N).
const int N = 1e5 + 50;
//fft begin
const double pi = acos(-1.0);
typedef long long ll;
// Minimal complex number used by the FFT routines.
struct cp {
    double r, i; // real and imaginary parts
    cp(double r = 0, double i = 0): r(r), i(i) {}
    // The operators are const so they also work on const operands and temporaries.
    cp operator + (const cp &b) const {
        return cp(r + b.r, i + b.i);
    }
    cp operator - (const cp &b) const {
        return cp(r - b.r, i - b.i);
    }
    // (r + i*j)(b.r + b.i*j) = (r*b.r - i*b.i) + (r*b.i + i*b.r)*j
    cp operator * (const cp &b) const {
        return cp(r * b.r - i * b.i, r * b.i + i * b.r);
    }
};
// Reorders a[0..len) into bit-reversed index order, the input layout the
// iterative butterfly loop in fft() expects. len is expected to be a power of two.
void change(cp a[], int len) {
    int rev = 0; // bit-reversal of the previous index
    for (int idx = 1; idx + 1 < len; ++idx) {
        // "Reversed increment": add 1 starting from the top bit, carrying downward.
        int bit = len >> 1;
        for (; rev & bit; bit >>= 1) rev ^= bit;
        rev |= bit;
        // Swap each pair only once, from its smaller index.
        if (idx < rev) std::swap(a[idx], a[rev]);
    }
}
// In-place iterative radix-2 FFT over a[0..len); len is expected to be a power of two.
// op == 1 runs the forward transform, op == -1 the inverse. For the inverse,
// only the real parts are divided by len; imaginary parts are left unscaled.
void fft(cp a[], int len, int op) {
    change(a, len);
    for (int step = 2; step <= len; step <<= 1) {
        const int half = step / 2;
        const double ang = -op * 2 * pi / step;
        const cp unit(cos(ang), sin(ang)); // primitive step-th root of unity
        for (int base = 0; base < len; base += step) {
            cp w(1, 0);
            for (int off = 0; off < half; ++off) {
                cp &lo = a[base + off];
                cp &hi = a[base + off + half];
                cp even = lo;
                cp odd = w * hi;
                lo = even + odd;
                hi = even - odd;
                w = w * unit;
            }
        }
    }
    if (op == -1) {
        for (int idx = 0; idx < len; ++idx) a[idx].r /= len;
    }
}
//fft end
// Sieve output filled in main(): prime[1..prime[0]] lists the primes below N
// in increasing order; notprime[x] is nonzero for composite x.
int prime[N], notprime[N];
// Tree adjacency lists (main starts the search at node 1; 0 is used as the
// "no parent" sentinel).
vector<int> G[N];
// Centroid-search state: S is the size of the component being searched,
// maxx the smallest "largest remaining piece" seen so far, root the node achieving it.
int S, maxx, root;
// vis[u] is set once u has been used as a centroid; msze/sze are per-node
// scratch values written by getroot().
int vis[N], msze[N], sze[N];
void getroot(int u, int f) {
sze[u] = 1; msze[u] = 0;
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (v == f || vis[v]) continue;
getroot(v, u);
sze[u] += sze[v];
msze[u] = max(msze[u], sze[v]);
}
msze[u] = max(msze[u], S - sze[u]);
if (msze[u] < maxx) {
maxx = msze[u];
root = u;
}
}
// dis[d]: number of nodes at distance d found by the current getdis() walk.
// cnt: not referenced by the visible code (main declares its own local cnt).
// len1: largest distance recorded by the current getdis() walk.
int dis[N / 2], cnt, len1 = 0;
void getdis(int u, int f, int w) {
dis[w]++;
len1 = max(len1, w);
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (v == f || vis[v]) continue;
getdis(v, u, w + 1);
}
}
// FFT scratch buffer. With n <= 50,000 the transform length requested by
// calc() stays far below N * 4 entries.
cp a[N * 4];
// Counts ordered pairs (x, y) of nodes in the unvisited part of the tree
// reachable from u, with x == y allowed, whose depth sum is prime. Depth is
// measured so that u is at depth w and each edge adds 1. The depth histogram
// dis[] is squared with an FFT, and the coefficients at prime indices are summed.
// Precondition: dis[] is all zeros on entry. It is restored to all zeros
// before returning, so the whole array never has to be memset per call.
ll calc(int u, int w) {
    len1 = 0;
    getdis(u, 0, w);   // fills dis[w..len1]
    int len = 1;
    // The product has indices up to 2 * len1, so len must exceed that.
    while (len < (len1 + 1) * 2) len <<= 1;
    for (int i = 0; i <= len1; i++) {
        a[i] = cp(dis[i], 0);
        dis[i] = 0;    // clear only the entries this walk could have touched
    }
    for (int i = len1 + 1; i < len; i++) {
        a[i] = cp(0, 0);
    }
    fft(a, len, 1);
    for (int i = 0; i < len; i++) {
        a[i] = a[i] * a[i];
    }
    fft(a, len, -1);
    ll res = 0;
    // prime[] is increasing, so stop at the first prime outside a[0..len).
    for (int i = 1; i <= prime[0] && prime[i] < len; i++) {
        res += (ll)(a[prime[i]].r + 0.5);   // round the FFT result to the nearest integer
    }
    return res;
}
ll ans;
void dfs(int u) {
vis[u] = 1;
ans += calc(u, 0);
for (int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if (vis[v]) continue;
ans -= calc(v, 1);
S = sze[v];
root = 0;
maxx = 1e9;
getroot(v, u);
dfs(root);
}
}
int main() {
int cnt = 0;
for (int i = 2; i < N; i++) {
if (!notprime[i]) prime[++cnt] = i;
for (int j = 1; j <= cnt && i * prime[j] < N; j++) {
notprime[i * prime[j]] = 1;
if (i % prime[j] == 0) break;
}
}
prime[0] = cnt;
int n; scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
S = n;
maxx = 1e9;
root = 0;
getroot(1, 0);
dfs(root);
ll tmp = (ll)n * (ll)(n - 1);
printf("%.9f\n", (double)ans / tmp);
return 0;
}