The Preliminary Contest for ICPC Asia Shanghai 2019 C. Triple

 

【传送门】

FFT第三题!

其实就是求有多少三元组 $(A_i, B_j, C_k)$ 满足较短两边之和大于等于最长边(即三个数能构成三角形,允许退化)。

考虑补集(容斥):枚举最长边来自哪个数组,统计另外两个数组中有多少对边之和严格小于它。三种情况之和就是不合法的三元组个数,答案为 $n^3$ 减去它。由于边长非负,一个三元组中至多有一条边大于另外两边之和,所以三种情况不会重复计数。

当 $n \leq 1000$ 时直接 $O(n^2)$ 暴力枚举数对即可。因为 FFT 的复杂度是 $O(s \log s)$($s$ 表示值域,可达 $10^5$),与 $n$ 无关,$100$ 组数据都用 FFT 会吃不消。

$n > 1000$ 时再用 FFT 求出两个数组之和的分布即可。

#include <bits/stdc++.h>

// Minimal complex number used by the FFT: real part r, imaginary part i.
struct Complex {
    double r, i;

    Complex(double r = 0, double i = 0) : r(r), i(i) {}

    // Reset both components to zero.
    void clear() {
        r = 0.0;
        i = 0.0;
    }

    Complex operator + (const Complex &o) const {
        const double re = r + o.r;
        const double im = i + o.i;
        return Complex(re, im);
    }

    Complex operator - (const Complex &o) const {
        const double re = r - o.r;
        const double im = i - o.i;
        return Complex(re, im);
    }

    // (r + i*j)(o.r + o.i*j) = (r*o.r - i*o.i) + (r*o.i + i*o.r)*j
    Complex operator * (const Complex &o) const {
        const double re = r * o.r - i * o.i;
        const double im = r * o.i + i * o.r;
        return Complex(re, im);
    }
};

// pi is computed once at startup; N bounds every per-index and per-value array.
const double pi = std::acos(-1.0);
const int N = 500007;  // 5e5 + 7
// n: length of each input array in the current case.
int n;
// limit: FFT length (a power of two); l: log2(limit).
int limit, l;
// r: bit-reversal permutation for the current FFT length.
int r[N];

// In-place iterative radix-2 FFT over a[0..n).
//   n  : transform length; must be a power of two, and the global bit-reversal
//        table r[0..n) must already be built for this length.
//   pd : +1 for the forward transform, -1 for the inverse transform (the
//        inverse also divides every entry by n).
// The twiddle factors of each level are evaluated directly with cos/sin. The
// alternative is repeated multiplication w = w * wn, whose rounding error grows
// with the level length. The caller must recover pair counts up to n^2 exactly
// by rounding, so that accumulated drift would eat into the margin.
void FFT(Complex *a, int n, int pd) {
    for (int i = 0; i < n; i++)
        if (i < r[i])
            std::swap(a[i], a[r[i]]);
    std::vector<Complex> w;
    for (int mid = 1; mid < n; mid <<= 1) {
        // w[k] = exp(pd * j * pi * k / mid), the k-th twiddle of this level.
        w.resize(mid);
        for (int k = 0; k < mid; k++)
            w[k] = Complex(std::cos(pi * k / mid), pd * std::sin(pi * k / mid));
        for (int len = mid << 1, j = 0; j < n; j += len) {
            for (int k = 0; k < mid; k++) {
                Complex u = a[j + k], v = w[k] * a[j + k + mid];
                a[j + k] = u + v;
                a[j + k + mid] = u - v;
            }
        }
    }
    if (pd < 0)
        for (int i = 0; i < n; i++)
            a[i] = Complex(a[i].r / n, a[i].i / n);
}

// Shorthand for the 64-bit counter type; a scoped alias rather than a textual macro.
using ll = long long;

// Input arrays, 1-indexed (A[1..n]); main sorts them before solving.
int A[N];
int B[N];
int C[N];
// cntX[v]: how many elements of X equal v. solve1/solve2 turn these into
// prefix counts in place.
ll cntA[N];
ll cntB[N];
ll cntC[N];

// Zero the count tables cntA, cntB and cntC on indices [0, val].
void init(int val) {
    for (int v = val; v >= 0; --v) {
        cntA[v] = 0;
        cntB[v] = 0;
        cntC[v] = 0;
    }
}

void solve1() {
    int val = 2 * std::max(A[n], std::max(B[n], C[n])) + 1;
    for (int i = 1; i <= val; i++)
        cntA[i] += cntA[i - 1], cntB[i] += cntB[i - 1], cntC[i] += cntC[i - 1];
    ll ans = 0;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++) {
            int cur = A[i] + B[j];
            ans += cntC[val] - cntC[cur];
            cur = A[i] + C[j];
            ans += cntB[val] - cntB[cur];
            cur = B[i] + C[j];
            ans += cntA[val] - cntA[cur];
        }
    ans = 1LL * n * n * n - ans;
    assert(ans >= 0);
    printf("%lld\n", ans);
}

// FFT work buffers for solve2: transforms of the value counts of A, B and C,
// plus a scratch buffer for each pairwise product and its inverse transform.
Complex a[N];
Complex b[N];
Complex c[N];
Complex res[N];

void solve2() {
    limit = 1, l = 0;
    int val = 2 * std::max(A[n], std::max(B[n], C[n])) + 1;
    while (limit <= val) limit <<= 1, l++;
    for (int i = 0; i < limit; i++)
        r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
    for (int i = 0; i < limit; i++)
        a[i] = Complex((double)cntA[i], 0.0), b[i] = Complex((double)cntB[i], 0.0), c[i] = Complex((double)cntC[i], 0.0);
    for (int i = 1; i <= val; i++)
        cntA[i] += cntA[i - 1], cntB[i] += cntB[i - 1], cntC[i] += cntC[i - 1];
    FFT(a, limit, 1); FFT(b, limit, 1); FFT(c, limit, 1);
    for (int i = 0; i < limit; i++)
        res[i] = a[i] * b[i];
    FFT(res, limit, -1);
    ll ans = 0;
    for (int i = 0; i < limit; i++) {
        ll temp = (ll)(res[i].r + 0.5);
        ans += (cntC[val] - cntC[i]) * temp;
    }
    for (int i = 0; i < limit; i++)
        res[i] = b[i] * c[i];
    FFT(res, limit, -1);
    for (int i = 0; i < limit; i++) {
        ll temp = (ll)(res[i].r + 0.5);
        ans += (cntA[val] - cntA[i]) * temp;
    }
    for (int i = 0; i < limit; i++)
        res[i] = a[i] * c[i];
    FFT(res, limit, -1);
    for (int i = 0; i < limit; i++) {
        ll temp = (ll)(res[i].r + 0.5);
        ans += (cntB[val] - cntB[i]) * temp;
    }
    ans = 1LL * n * n * n - ans;
    assert(ans >= 0);
    printf("%lld\n", ans);
}

// Reads T test cases. Each case gives n followed by the n values of A, B and
// C. For each case it prints "Case #k: " and then the number of index triples
// whose values satisfy the triangle inequality (computed by solve1/solve2).
int main() {
    int T;
    scanf("%d", &T);
    for (int kase = 1; kase <= T; kase++) {
        scanf("%d", &n);
        for (int i = 1; i <= n; i++)
            scanf("%d", A + i);
        for (int i = 1; i <= n; i++)
            scanf("%d", B + i);
        for (int i = 1; i <= n; i++)
            scanf("%d", C + i);
        // Clear every counter slot. The solvers prefix-sum cnt* up to about
        // twice the largest value (and solve2 reads it up to the FFT length),
        // so leftovers from an earlier, larger case must not survive.
        init(N - 1);
        // solve1/solve2 read each array's maximum from its last element.
        std::sort(A + 1, A + 1 + n);
        std::sort(B + 1, B + 1 + n);
        std::sort(C + 1, C + 1 + n);
        for (int i = 1; i <= n; i++)
            cntA[A[i]]++, cntB[B[i]]++, cntC[C[i]]++;
        printf("Case #%d: ", kase);
        // Brute force is cheaper for small n; the FFT cost depends on the
        // value range, not on n.
        if (n <= 1000) solve1();
        else solve2();
    }
    return 0;
}
View Code

 

posted @ 2019-11-18 20:39  Mrzdtz220  阅读(123)  评论(0编辑  收藏  举报