Kattis-aplusb A+B problem
Description
Given \(N\) integers in the range \([-50\, 000, 50\, 000]\), how many ways are there to pick three integers \(a_ i\), \(a_ j\), \(a_ k\), such that \(i\), \(j\), \(k\) are pairwise distinct and \(a_ i + a_ j = a_ k\)? Two ways are different if their ordered triples \((i, j, k)\) of indices are different.
Input
The first line of input consists of a single integer \(N\) (\(1 \leq N \leq 200\, 000\)). The next line consists of \(N\) space-separated integers \(a_1, a_2, \dots , a_ N\).
Output
Output an integer representing the number of ways.
Sample Input 1
4
1 2 3 4
Sample Output 1
4
Sample Input 2
6
1 1 3 3 4 6
Sample Output 2
10
题意
给出n个数字,范围从\([-50\, 000, 50\, 000]\),问有多少组不同的\((i, j, k)\)满足\(a_ i + a_ j = a_ k\)
题解
在多项式中,两项相乘,指数相加,所以我们可以用fft来解决此题.
首先,不考虑i,j,k必须不同,那么对于每个数,我们先加上50000使其变为正数,并把这一项指数对应的系数+1,这样我们对这个数组做一次卷积,遍历一遍原数组即可得到答案.
由于(i,j,k)必须不同,我们要减去自己和自己对答案产生贡献的,开一个数组\(b\),\(b[(x+M)*2]\)也就是在计算答案时要被减去的自己和自己相加的.
由于可能存在0,我们还要考虑0的影响,设0的个数为cnt0个,那么0+0=0和0+ai=ai都被重复计算了,总共多计算了\(2*cnt0*(cnt0-1)+2*cnt0*(n-cnt0)\)次,(第一个乘以2表示等号右边可以选择两个0中任意一个,第二个乘以2表示ai和0可以交换)所以要减去\(2*cnt0*(n-1)\)
代码
#include <bits/stdc++.h>
using namespace std;
const double pi = acos(-1.0);
const int N = 1e6 + 50;
typedef long long ll;
struct cp {
double r, i;
cp(double r = 0, double i = 0): r(r), i(i) {}
cp operator + (const cp &b) {
return cp(r + b.r, i + b.i);
}
cp operator - (const cp &b) {
return cp(r - b.r, i - b.i);
}
cp operator * (const cp &b) {
return cp(r * b.r - i * b.i, r * b.i + i * b.r);
}
};
void change(cp a[], int len) {
for (int i = 1, j = len / 2; i < len - 1; i++) {
if (i < j) swap(a[i], a[j]);
int k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(cp a[], int len, int op) {
change(a, len);
for (int h = 2; h <= len; h <<= 1) {
cp wn(cos(-op * 2 * pi / h), sin(-op * 2 * pi / h));
for (int j = 0; j < len; j += h) {
cp w(1, 0);
for (int k = j; k < j + h / 2; k++) {
cp u = a[k];
cp t = w * a[k + h / 2];
a[k] = u + t;
a[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (op == -1) {
for (int i = 0; i < len; i++) {
a[i].r /= len;
}
}
}
const int M = 50000;
ll num[N];
ll cnt[N];
cp a[N], b[N];
ll ans[N];
int main() {
int n;
scanf("%d", &n);
ll cnt0 = 0;
ll len1 = 0;
for (int i = 0; i < n; i++) {
scanf("%lld", &num[i]);
if (num[i] == 0) cnt0++;
cnt[num[i] + M]++;
len1 = max(len1, num[i] + M + 1);
}
//printf("%lld\n", len1);
ll len = 1;
while (len < 2 * len1) len <<= 1;//必须补两倍
//printf("%lld\n", len);
for (int i = 0; i < len1; i++) {
a[i] = cp(cnt[i], 0);
}
for (int i = len1; i < len; i++) {
a[i] = cp(0, 0);
}
fft(a, len, 1);
for (int i = 0; i < len; i++) {
a[i] = a[i] * a[i];
}
fft(a, len, -1);
for (int i = 0; i < len; i++) {
ans[i] = (ll)(a[i].r + 0.5);
}
for (int i = 0; i < n; i++) {//删去自己和自己的
ans[(num[i] + M) * 2]--;
}
ll res = 0;
for (int i = 0; i < n; i++) {
res += ans[num[i] + M * 2];
}
res -= 2 * cnt0 * (n - 1);
printf("%lld\n", res);
return 0;
}