Kattis-aplusb A+B problem

Description

Given \(N\) integers in the range \([-50\, 000, 50\, 000]\), how many ways are there to pick three integers \(a_ i\), \(a_ j\), \(a_ k\), such that \(i\), \(j\), \(k\) are pairwise distinct and \(a_ i + a_ j = a_ k\)? Two ways are different if their ordered triples \((i, j, k)\) of indices are different.

Input

The first line of input consists of a single integer \(N\) (\(1 \leq N \leq 200\, 000\)). The next line consists of \(N\) space-separated integers \(a_1, a_2, \dots , a_ N\).

Output

Output an integer representing the number of ways.

Sample Input 1

4
1 2 3 4

Sample Output 1

4

Sample Input 2

6
1 1 3 3 4 6

Sample Output 2

10

题意

给出n个数字,范围从\([-50\, 000, 50\, 000]\),问有多少组不同的\((i, j, k)\)满足\(a_ i + a_ j = a_ k\)

题解

在多项式中,两项相乘,指数相加,所以我们可以用fft来解决此题.

首先,不考虑i,j,k必须不同,那么对于每个数,我们先加上50000使其变为正数,并把这一项指数对应的系数+1,这样我们对这个数组做一次卷积,遍历一遍原数组即可得到答案.

由于(i,j,k)必须不同,我们要减去自己和自己对答案产生贡献的,开一个数组\(b\),\(b[(x+M)*2]\)也就是在计算答案时要被减去的自己和自己相加的.

由于可能存在0,我们还要考虑0的影响,设0的个数为cnt0个,那么0+0=0和0+ai=ai都被重复计算了,总共多计算了\(2*cnt0*(cnt0-1)+2*cnt0*(n-cnt0)\)次,(第一个乘以2表示等号右边可以选择两个0中任意一个,第二个乘以2表示ai和0可以交换)所以要减去\(2*cnt0*(n-1)\)

代码

#include <bits/stdc++.h>
using namespace std;
const double pi = acos(-1.0);
const int N = 1e6 + 50;
typedef long long ll;

struct cp {
    double r, i;
    cp(double r = 0, double i = 0): r(r), i(i) {}
    cp operator + (const cp &b) {
        return cp(r + b.r, i + b.i);
    }
    cp operator - (const cp &b) {
        return cp(r - b.r, i - b.i);
    }
    cp operator * (const cp &b) {
        return cp(r * b.r - i * b.i, r * b.i + i * b.r);
    }
};
void change(cp a[], int len) {
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(a[i], a[j]);
        int k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        if (j < k) j += k;
    }
}
void fft(cp a[], int len, int op) {
    change(a, len);
    for (int h = 2; h <= len; h <<= 1) {
        cp wn(cos(-op * 2 * pi / h), sin(-op * 2 * pi / h));
        for (int j = 0; j < len; j += h) {
            cp w(1, 0);
            for (int k = j; k < j + h / 2; k++) {
                cp u = a[k];
                cp t = w * a[k + h / 2];
                a[k] = u + t;
                a[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (op == -1) {
        for (int i = 0; i < len; i++) {
            a[i].r /= len;
        }
    }
}

const int M = 50000;

ll num[N];
ll cnt[N];

cp a[N], b[N];

ll ans[N];

int main() {
    int n;
    scanf("%d", &n);
    ll cnt0 = 0;
    ll len1 = 0;
    for (int i = 0; i < n; i++) {
        scanf("%lld", &num[i]);
        if (num[i] == 0) cnt0++;
        cnt[num[i] + M]++;
        len1 = max(len1, num[i] + M + 1);
    }
    //printf("%lld\n", len1);
    ll len = 1;
    while (len < 2 * len1) len <<= 1;//必须补两倍
    //printf("%lld\n", len);
    for (int i = 0; i < len1; i++) {
        a[i] = cp(cnt[i], 0);
    }
    for (int i = len1; i < len; i++) {
        a[i] = cp(0, 0);
    }
    fft(a, len, 1);
    for (int i = 0; i < len; i++) {
        a[i] = a[i] * a[i];
    }
    fft(a, len, -1);
    for (int i = 0; i < len; i++) {
        ans[i] = (ll)(a[i].r + 0.5);
    }
    for (int i = 0; i < n; i++) {//删去自己和自己的
        ans[(num[i] + M) * 2]--;
    }
    ll res = 0;
    for (int i = 0; i < n; i++) {
        res += ans[num[i] + M * 2];
    }
    res -= 2 * cnt0 * (n - 1);
    printf("%lld\n", res);
    return 0;
}
posted @ 2020-01-13 11:41  Artoriax  阅读(285)  评论(0编辑  收藏  举报