SPOJ - Triple Sums

 

【传送门】

FFT第一题!

构造多项式 $A(x) = \sum_i x^{s_i}$。由于 $s_i$ 可能为负，代码中先把每个 $s_i$ 加上偏移 $20000$，使所有指数非负，输出时再减去 $3 \times 20000 = 60000$。

不考虑题目中 $i < j < k$ 的条件（即允许下标重复，并且区分顺序），那么 $A^3(x)$ 中 $x^s$ 的系数就是满足 $s_i + s_j + s_k = s$ 的三元组个数。

考虑容斥。下面 $x, y, z$ 表示下标互不相同的元素：$\sum x^3$ 对单个元素求和，$\sum x^2 y$ 对有序对 $(x, y)$ 求和，$\sum xyz$ 对无序三元组求和。

$$(\sum x)^3 = \sum x^3 + 3 \sum x^2 y + 6\sum xyz$$

$$(\sum x^2) (\sum x)= \sum x^3 + \sum x^2 y$$

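所以 $$\sum xyz = \dfrac{(\sum x)^3 - 3 (\sum x^2)(\sum x) + 2 \sum x^3}{6}$$

对应到多项式：$(\sum x)^3$ 就是 $A^3(x)$，$(\sum x^2)(\sum x)$ 就是 $A(x^2)A(x)$，$\sum x^3$ 就是 $A(x^3)$。用 FFT 算出前两个乘积，再按上式逐项合并即可。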

#include <bits/stdc++.h>

// Minimal complex number used as the FFT element type.
// `r` is the real part and `i` the imaginary part; both default to 0.
struct Complex {
    double r;
    double i;

    Complex(double r = 0.0, double i = 0.0): r(r), i(i) {}

    // Component-wise sum.
    Complex operator + (const Complex &rhs) const {
        const double re = r + rhs.r;
        const double im = i + rhs.i;
        return Complex(re, im);
    }

    // Component-wise difference.
    Complex operator - (const Complex &rhs) const {
        const double re = r - rhs.r;
        const double im = i - rhs.i;
        return Complex(re, im);
    }

    // (r + i*j)(rhs.r + rhs.i*j) = (r*rhs.r - i*rhs.i) + (r*rhs.i + i*rhs.r)*j
    Complex operator * (const Complex &rhs) const {
        const double re = r * rhs.r - i * rhs.i;
        const double im = r * rhs.i + i * rhs.r;
        return Complex(re, im);
    }
};

// pi, used for the FFT twiddle-factor angles.
const double pi = std::acos(-1.0);
// Capacity of every global array. It must cover the FFT length `limit`
// chosen in main (131072) and the largest histogram index (C[3 * 40000]).
const int N = 2e5 + 7;
// n: number of input values; limit: FFT length (a power of two);
// r: bit-reversal permutation of [0, limit); l: log2(limit).
int n, limit, r[N], l;
// Histograms of the shifted input values x: A[x]++, B[2x]++, C[3x]++.
int A[N], B[N], C[N];
// FFT work buffers: `a` is filled from A and `b` from B.
Complex a[N], b[N];

// In-place iterative radix-2 FFT over a[0 .. limit-1].
//   pd ==  1: forward transform.
//   pd == -1: inverse transform, including the final division by `limit`.
// Relies on the globals `limit` (a power of two) and `r` (its bit-reversal
// permutation), which the caller must set up beforehand.
void FFT(Complex *a, int pd) {
    // Permute into bit-reversed order so the butterflies can run in place.
    for (int i = 0; i < limit; i++)
        if (i < r[i])
            std::swap(a[i], a[r[i]]);
    // Twiddle table for the current level; mid never exceeds limit / 2.
    std::vector<Complex> w(limit / 2);
    for (int mid = 1; mid < limit; mid <<= 1) {
        // Compute each twiddle w_k = exp(pd * i * pi * k / mid) directly.
        // The recurrence w = w * wn accumulates rounding error that grows
        // with mid. That error matters here because the transformed values
        // are later cubed and rounded to integers.
        for (int k = 0; k < mid; k++)
            w[k] = Complex(std::cos(pi * k / mid), pd * std::sin(pi * k / mid));
        const int len = mid << 1;
        for (int j = 0; j < limit; j += len) {
            for (int k = 0; k < mid; k++) {
                Complex u = a[j + k], t = w[k] * a[j + k + mid];
                a[j + k] = u + t;
                a[j + k + mid] = u - t;
            }
        }
    }
    // The inverse transform needs the 1/limit normalisation.
    if (pd == -1)
        for (int i = 0; i < limit; i++)
            a[i] = Complex(a[i].r / limit, a[i].i / limit);
}

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        int x;
        scanf("%d", &x);
        x += 20000;
        A[x]++;
        B[x * 2]++;
        C[x * 3]++;
    }
    for (int i = 0; i <= 40000; i++)
        a[i] = Complex((double)A[i], 0.0);
    for (int i = 0; i <= 80000; i++)
        b[i] = Complex((double)B[i], 0.0);
    limit = 1;
    while (limit <= 40000 + 80000)
        limit <<= 1, l++;
    for (int i = 0; i < limit; i++)
        r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
    FFT(a, 1);
    FFT(b, 1);
    for (int i = 0; i < limit; i++)
        b[i] = b[i] * a[i];
    for (int i = 0; i < limit; i++)
        a[i] = a[i] * a[i] * a[i];
    FFT(a, -1);
    FFT(b, -1);
    for (int i = 0; i <= 120000; i++) {
        long long ans = (long long)((a[i].r - 3.0 * b[i].r + 2.0 * C[i]) / 6.0 + 0.5);
        if (ans > 0)
            printf("%d : %lld\n", i - 60000, ans);
    }
    return 0;
}
View Code

 

posted @ 2019-11-17 18:54  Mrzdtz220  阅读(142)  评论(0)  编辑  收藏  举报