Aizu 2560 Point Distance FFT
题意:
有一个\(N \times N\)的方阵,第\(x\)行第\(y\)列有\(C_{x,y}\)个点\((0 \leq C_{x,y} \leq 9)\)。
任选两个不同的点,求两点欧几里德距离的均值(或期望)。
然后按距离从小到大输出该距离的平方\(d_i\)和对应的点对数目\(c_i\)。
分析:
首先要化二维为一维,一般来讲给点\((x,y)\)编号\(x \times N+y(0\leq x, y < N)\)。
这里为了区分行和列从而方便计算距离,按照\(x \times 2N + y\)的方式给点编号。
这样对于两个点\((x_1,y_1)\)和\((x_2, y_2)\),对应编号分别为\(id_1 = x_1 \times 2N + y_1\)和\(id_2 = x_2 \times 2N + y_2(id_1 < id_2)\)。
两点之间的行距\(dx=\left \lceil \frac{id_1 - id_2 + N}{2} \right \rceil\)
两点之间的列距\(dy=\left | id_1 - id_2 -dx\times 2N \right |\)
然后用\(FFT\)计算两个多项式:
\[P(x)=\sum C_{i,j}x^{id_{i,j}}
\]
\[Q(x)=\sum C_{i,j}x^{-id_{i,j}}
\]
的乘积。
距离为\(0\)的点对注意去重或者单独计算。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <complex>
using namespace std;
const double PI = acos(-1.0);
typedef long long LL;
typedef complex<double> Complex;
void FFT(Complex* P, int n, int op) {
for(int i = 1, j = 0; i < n - 1; i++) {
for(int s = n; j ^= s >>= 1, ~j&s; );
if(i < j) swap(P[i], P[j]);
}
int log = 0;
while((1 << log) < n) log++;
for(int s = 0; s < log; s++) {
int m = 1 << s;
int m2 = m << 1;
Complex wm(cos(PI / m), sin(PI / m) * op);
for(int i = 0; i < n; i += m2) {
Complex w(1, 0);
for(int j = 0; j < m; j++, w *= wm) {
Complex u = P[i + j];
Complex t = P[i + j + m] * w;
P[i + j] = u + t;
P[i + j + m] = u - t;
}
}
}
if(op == -1) for(int i = 0; i < n; i++) P[i].real(P[i].real() / n);
}
Complex P[2][1 << 22];
int n;
LL cnt[2100000];
double dist(double x, double y) {
return sqrt(x * x + y * y);
}
int main()
{
scanf("%d", &n);
int sum = 0;
int offset = (n - 1) * (n * 2 + 1);
for(int i = 0; i < n; i++) {
for(int j = 0; j < n; j++) {
int x; scanf("%d", &x);
sum += x;
cnt[0] += (x - 1) * x / 2;
int id = i * 2 * n + j;
P[0][id] = x;
P[1][offset-id] = x;
}
}
int s = 1;
while(s < offset * 2 + 1) s <<= 1;
FFT(P[0], s, 1); FFT(P[1], s, 1);
for(int i = 0; i < s; i++) P[0][i] *= P[1][i];
FFT(P[0], s, -1);
double ans = 0;
for(int i = 1; i <= offset; i++) {
LL t = (LL)(P[0][offset + i].real() + 0.5);
if(!t) continue;
int dx = ((i / n) + 1) >> 1;
int dy = abs(i - dx * n * 2);
ans += dist(dx, dy) * t;
cnt[dx*dx+dy*dy] += t;
}
ans /= (double)sum * (sum - 1) / 2;
printf("%.10f\n", ans);
int num = 0;
int top = (n - 1) * (n - 1) * 2;
for(int i = 0; i <= top && num < 10000; i++) if(cnt[i]) {
printf("%d %lld\n", i, cnt[i]);
num++;
}
return 0;
}