The Preliminary Contest for ICPC Asia Shanghai 2019 C. Triple
FFT第三题!
其实就是统计有多少个三元组 $(A_i, B_j, C_k)$ 满足两条较短边之和大于等于最长边。
考虑补集:不满足条件的三元组中,必有一条边严格大于另外两边之和,且这样的边至多只有一条(若有两条,各自都大于对方加第三边,矛盾)。因此枚举由哪个数组提供这条最长边,统计另外两个数组中有多少对边之和严格小于它;这三种情况互不相交,答案就是 $n^3$ 减去三者之和。
当 $n \leq 1000$ 时直接 $O(n^2)$ 暴力(配合前缀和计数):若仍用 FFT,每组复杂度为 $O(s\log s)$,其中 $s$ 表示值域,可达 $10^5$,$100$ 组数据吃不消。
$n > 1000$ 时,用 FFT 求出任意两个数组之和的分布(卷积),再配合第三个数组的前缀和计数即可。
#include <bits/stdc++.h>

// Minimal complex number used by the FFT below.
struct Complex {
    double r, i;
    void clear() { r = i = 0.0; }
    Complex(double r = 0, double i = 0) : r(r), i(i) {}
    Complex operator+(const Complex &p) const { return Complex(r + p.r, i + p.i); }
    Complex operator-(const Complex &p) const { return Complex(r - p.r, i - p.i); }
    Complex operator*(const Complex &p) const {
        return Complex(r * p.r - i * p.i, r * p.i + i * p.r);
    }
};

const double pi = acos(-1.0);
const int N = 5e5 + 7;

// limit: current FFT length (power of two); l: log2(limit); r: bit-reversal table.
int n, limit, l, r[N];

// In-place iterative radix-2 FFT on a[0..n).
// pd = 1 -> forward transform; pd = -1 -> inverse transform (including the 1/n scaling).
// Requires r[0..n) to hold the bit-reversal permutation for length n.
void FFT(Complex *a, int n, int pd) {
    for (int i = 0; i < n; i++)
        if (i < r[i]) std::swap(a[i], a[r[i]]);
    for (int mid = 1; mid < n; mid <<= 1) {
        const Complex wn(cos(pi / mid), pd * sin(pi / mid));
        // `step` instead of `l` so the global log2 length is not shadowed.
        for (int step = mid << 1, j = 0; j < n; j += step) {
            Complex w(1.0, 0.0);
            for (int k = 0; k < mid; k++, w = w * wn) {
                Complex u = a[k + j], v = w * a[k + j + mid];
                a[k + j] = u + v;
                a[k + j + mid] = u - v;
            }
        }
    }
    if (pd < 0)
        for (int i = 0; i < n; i++) a[i] = Complex(a[i].r / n, a[i].i / n);
}

typedef long long ll;

// A/B/C hold the three input arrays (1-indexed, sorted in main).
// cnt* first hold per-value counts and are later turned into prefix sums.
int A[N], B[N], C[N];
ll cntA[N], cntB[N], cntC[N];

// Zeroes the count arrays on [0, val].
void init(int val) {
    for (int i = 0; i <= val; i++) cntA[i] = cntB[i] = cntC[i] = 0;
}

// One past the largest possible sum of two edges. The arrays are sorted, so X[n] is each array's maximum.
static int sumBound() {
    return 2 * std::max(A[n], std::max(B[n], C[n])) + 1;
}

// Turns the per-value counts on [0, val] into prefix sums: cntX[v] = #{elements of X <= v}.
static void buildPrefix(int val) {
    for (int i = 1; i <= val; i++)
        cntA[i] += cntA[i - 1], cntB[i] += cntB[i - 1], cntC[i] += cntC[i - 1];
}

// Prints n^3 minus the number of "degenerate" triples, i.e. those where one edge is strictly
// longer than the sum of the other two. At most one edge can be that long, so the three
// cases counted are disjoint.
static void printAnswer(ll degenerate) {
    ll ans = 1LL * n * n * n - degenerate;
    assert(ans >= 0);
    printf("%lld\n", ans);
}

// O(n^2) brute force: for every pair from two arrays, count the third-array elements > their sum.
void solve1() {
    const int val = sumBound();
    buildPrefix(val);
    ll bad = 0;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) {
            bad += cntC[val] - cntC[A[i] + B[j]];
            bad += cntB[val] - cntB[A[i] + C[j]];
            bad += cntA[val] - cntA[B[i] + C[j]];
        }
    printAnswer(bad);
}

Complex a[N], b[N], c[N], res[N];

// Multiplies two forward-transformed count arrays pointwise and inverse-transforms the product.
// This gives pairs[s], the number of pairs whose sum is s. Returns sum over s of
// pairs[s] * #{third-array elements > s}, where `prefixThird` is the third array's prefix-count
// table (valid on [0, val]). Sums above val - 1 are impossible, so the loop stops at val.
static ll countDegenerate(const Complex *fx, const Complex *fy, const ll *prefixThird, int val) {
    for (int i = 0; i < limit; i++) res[i] = fx[i] * fy[i];
    FFT(res, limit, -1);
    ll total = 0;
    for (int i = 0; i <= val; i++) {
        const ll pairs = std::llround(res[i].r);  // round away FFT floating-point noise
        total += (prefixThird[val] - prefixThird[i]) * pairs;
    }
    return total;
}

// FFT version: the pair-sum distributions come from convolving the value-count arrays.
void solve2() {
    const int val = sumBound();
    limit = 1, l = 0;
    while (limit <= val) limit <<= 1, l++;  // guarantees limit > val, so index val is in range
    for (int i = 0; i < limit; i++) r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
    // Load the raw per-value counts before they are overwritten with prefix sums.
    for (int i = 0; i < limit; i++)
        a[i] = Complex((double)cntA[i], 0.0),
        b[i] = Complex((double)cntB[i], 0.0),
        c[i] = Complex((double)cntC[i], 0.0);
    buildPrefix(val);
    FFT(a, limit, 1);
    FFT(b, limit, 1);
    FFT(c, limit, 1);
    const ll bad = countDegenerate(a, b, cntC, val)
                 + countDegenerate(b, c, cntA, val)
                 + countDegenerate(a, c, cntB, val);
    printAnswer(bad);
}

int main() {
    int T;
    if (scanf("%d", &T) != 1) return 0;  // no input: nothing to do (T would otherwise be uninitialized)
    for (int kase = 1; kase <= T; kase++) {
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) scanf("%d", A + i);
        for (int i = 1; i <= n; i++) scanf("%d", B + i);
        for (int i = 1; i <= n; i++) scanf("%d", C + i);
        // Clear the whole table: the previous case left prefix sums (and raw counts) behind.
        init(N - 1);
        std::sort(A + 1, A + 1 + n);
        std::sort(B + 1, B + 1 + n);
        std::sort(C + 1, C + 1 + n);
        for (int i = 1; i <= n; i++) cntA[A[i]]++, cntB[B[i]]++, cntC[C[i]]++;
        printf("Case #%d: ", kase);
        if (n <= 1000)
            solve1();
        else
            solve2();
    }
    return 0;
}