BZOJ 3513: [MUTC2013]idiots
忽略不能取同一根的要求。
构造多项式 \(A(x) = \sum x_{a_i}\)
那么 \(A^2(x)\) 就是取两个木棍,组成长度为 \(s\) 的方案。
容斥后得到
\[(\sum x)^2 = \sum x^2 + 2\sum xy
\]
\[\sum xy = \dfrac{(\sum x)^2 - \sum x^2}{2}
\]
这里的 \(x\) 和 \(y\) 就可以看成有序对了。
设多项式 \(B(s)\) 为上述多项式 。
枚举最长边,假设最长边为 \(s\),在其前面的木棍(也就是长度小于它的木棍和长度等于它但不包括它的木棍的个数)有 \(i - 1\) 根,在其后面的木棍有 \(n-i\) 根。
那么符合情况的有序对对数为 \(\sum_{j > s} B(j) - ((i - 1)(n - i + 1) + n - i + C_{n - i}^{2})\)。
然后除一下 \(C_n ^3\) 就可以了。
#include <bits/stdc++.h>
struct Complex {
double r, i;
void clear() { r = i = 0.0; }
Complex(double r = 0, double i = 0): r(r), i(i) {}
Complex operator + (const Complex &p) const { return Complex(r + p.r, i + p.i); }
Complex operator - (const Complex &p) const { return Complex(r - p.r, i - p.i); }
Complex operator * (const Complex &p) const { return Complex(r * p.r - i * p.i, r * p.i + i * p.r); }
};
const double pi = acos(-1.0);
const int N = 5e5 + 7;
int r[N];
void FFT(Complex *a, int n, int pd) {
for (int i = 0; i < n; i++)
if (i < r[i])
std::swap(a[i], a[r[i]]);
for (int mid = 1; mid < n; mid <<= 1) {
Complex wn(cos(pi / mid), pd * sin(pi / mid));
for (int l = mid << 1, j = 0; j < n; j += l) {
Complex w(1.0, 0.0);
for (int k = 0; k < mid; k++, w = w * wn) {
Complex u = a[k + j], v = w * a[k + j + mid];
a[k + j] = u + v;
a[k + j + mid] = u - v;
}
}
}
if (pd < 0)
for (int i = 0; i < n; i++)
a[i] = Complex(a[i].r / n, a[i].i / n);
}
#define ll long long
int len, n, l;
Complex A[N];
int a[N], cntA[N], cntB[N];
ll sum[N];
ll C(int n) {
return 1LL * n * (n - 1) / 2;
}
ll C3(int n) {
return 1LL * n * (n - 1) * (n - 2) / 6;
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
int m = 0;
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
m = std::max(m, a[i]);
cntA[a[i]]++;
cntB[a[i] * 2]++;
}
len = 1;
l = 0;
while (len <= 2 * m) len <<= 1, l++;
for (int i = 0; i < len; i++)
r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
for (int i = 0; i < len; i++)
A[i] = Complex((double)cntA[i], 0.0);
FFT(A, len, 1);
for (int i = 0; i < len; i++)
A[i] = A[i] * A[i];
FFT(A, len, -1);
for (int i = 0; i < len; i++)
sum[i] = (ll)((A[i].r - cntB[i]) / 2.0 + 0.5);
for (int i = 1; i < len; i++)
sum[i] += sum[i - 1];
std::sort(a + 1, a + 1 + n);
ll ans = 0;
for (int i = 1; i <= n; i++) {
ans += sum[len - 1] - sum[a[i]];
ll temp = 1LL * (i - 1) * (n - i + 1) + n - i + C(n - i);
ans -= temp;
}
ll all = C3(n);
double res = (double)(ans * 1.0 / all);
printf("%.7f\n", res);
for (int i = 0; i < len; i++) {
A[i].clear();
cntA[i] = cntB[i] = 0;
sum[i] = 0;
}
}
}