3-idiots
个人理解
第一道\(FFT\)练习题,尽管开始会用模板,但还是心存疑惑
- \(IDFT\)的过程还是不会证
- 自底向上的迭代,取值过程依然模糊
- 线性卷积和循环卷积为什么在\(L > (n + m - 1)\)时相等
- ········
当你学会了\(FFT\),那么这道题剩下的就是一丢丢容斥(枚举不合法的情况)kuangbin聚聚的博客讲的很清楚,就是自己写的时候要仔细!
long long cnt = 0;
for (int i = 0; i < n; ++i) {
cnt += (sum[len - 1] - sum[a[i]]); // 0, 1, ···,len - 1
cnt -= (long long)i * (n - i - 1); // 一个取大一个取小
cnt -= (n - 1); // 一个取本身,一个取其它数
cnt -= (long long)(n - i - 1) * (n - i - 2) / 2;
}
代码
struct Complex {
double real, image;
Complex() {}
Complex(double real, double image) : real(real), image(image) {}
Complex operator + (const Complex &a) const {
return Complex(real + a.real, image + a.image);
}
Complex operator - (const Complex &a) const {
return Complex(real - a.real, image - a.image);
}
Complex operator * (const Complex &a) const {
return Complex(real * a.real - image * a.image, image * a.real + real * a.image);
}
};
int rev(int id, int len) {
int pos = 0;
for (int i = 0; (1 << i) < len; ++i) {
pos <<= 1;
if (id & (1 << i)) pos |= 1;
}
return pos;
}
Complex A[500005];
void FFT(Complex *a, int len, int DFT) {
rep(i, 0, len) A[rev(i, len)] = a[i];
for (int s = 1; (1 << s) <= len; ++s) {
int m = (1 << s);
Complex wm = Complex(cos(DFT * 2 * PI / m), sin(DFT * 2 * PI / m));
for (int i = 0; i < len; i += m) {
Complex w = Complex(1, 0);
for (int j = 0; j < (m >> 1); ++j) {
Complex t = A[i + j];
Complex u = w * A[i + j + (m >> 1)];
A[i + j] = t + u;
A[i + j + (m >> 1)] = t - u;
w = w * wm;
}
}
}
if (DFT == -1) rep(i, 0, len) A[i].real /= len, A[i].image /= len;
rep(i, 0, len) a[i] = A[i];
}
const int N = 500005;
int n;
int cnt[N], b[N];
Complex a[N];
LL sum[N], num[N];
int main()
{
BEGIN() {
mem(cnt, 0);
mem(num, 0);
sc(n);
int ma = 0;
rep(i, 0, n) {
sc(b[i]);
int x = b[i];
cnt[x]++;
num[x + x]--;
ma = max(ma, x);
}
int sa = 0;
while((1 << sa) < (ma + 1)) sa++;
int m = 1 << (sa + 1);
rep(i, 0, m) a[i] = Complex(cnt[i], 0);
FFT(a, m, 1);
rep(i, 0, m) a[i] = a[i] * a[i];
FFT(a, m, -1);
rep(i, 0, m) num[i] += (LL)(a[i].real + 0.5), num[i] /= 2;
rep(i, 1, m) sum[i] = sum[i - 1] + num[i];
sort(b, b + n);
LL cnt = 0;
rep(i, 0, n) {
cnt += (sum[m - 1] - sum[b[i]]);
cnt -= (LL)(n - 1 - i) * i;
cnt -= (n - 1);
cnt -= (LL)(n - 1 - i) * (n - i - 2) / 2;
}
LL tot = (LL)n * (n - 1) * (n - 2) / 6;
printf("%.7f\n", (double)cnt / tot);
}
return 0;
}