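// Luogu [P3723] "礼物" (Gift)
// Technique: FFT
// https://www.luogu.org/problemnew/solution/P3723
// Key idea: rewrite the cross term so that it takes the form of a convolution.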
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
using namespace std;
// Capacity shared by every fixed-size global array in this file
// (the two FFT buffers, the bit-reversal table and both input arrays).
const int MAXN = 400005;
// pi, obtained as acos(-1); used to build the FFT twiddle factors.
const double PI = acos(-1);
// Reads the next decimal integer from stdin using getchar().
//
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated. Digits are then consumed
// until the first non-digit (that terminating character is consumed as well).
//
// Returns the parsed value, or 0 if end of input is reached before any digit.
int init() {
    int rv = 0, fh = 1;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters, so exhausted input ends the loop instead of spinning.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') fh = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        rv = rv * 10 + (c - '0');
        c = getchar();
    }
    return fh * rv;
}
// Minimal complex number (x = real part, y = imaginary part) offering just
// the arithmetic the FFT needs: addition, subtraction and multiplication.
struct Complex {
    double x, y;

    // Both parts default to zero, so value-initialised arrays start at 0+0i.
    Complex(double re = 0.0, double im = 0.0) : x(re), y(im) {}

    Complex operator + (const Complex &rhs) const {
        const double re = x + rhs.x;
        const double im = y + rhs.y;
        return Complex(re, im);
    }

    Complex operator - (const Complex &rhs) const {
        const double re = x - rhs.x;
        const double im = y - rhs.y;
        return Complex(re, im);
    }

    // (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
    Complex operator * (const Complex &rhs) const {
        return Complex(x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x);
    }
};

// FFT work buffers.
Complex a[MAXN], b[MAXN];
// n, m    : the two integers read first from the input (m is not used afterwards).
// lim     : FFT length; starts at 1 and is doubled in main, so it is a power of two.
// limcnt  : number of doublings applied to lim, i.e. log2(lim).
// rev     : bit-reversal permutation table consulted by fft().
// num1/2  : the two input sequences as read.
// ttt     : sum(num1) - sum(num2).
int n, m, lim = 1, limcnt, rev[MAXN], num1[MAXN], num2[MAXN], ttt;
// ans : running value of the printed answer.
// c   : -ttt / n rounded to the nearest integer.
long long ans = 0, c;
// In-place iterative radix-2 FFT over a[0 .. lim-1].
//
// Relies on the file-level state prepared by the caller: `lim` (transform
// length, a power of two) and `rev` (bit-reversal table for that length).
//
//   a   : buffer holding at least `lim` elements; overwritten with the result.
//   opt : +1 for the forward transform, -1 for the inverse transform.
//
// For opt == -1 the result is divided by `lim`, so a forward transform
// followed by an inverse one restores the original values.
void fft(Complex a[], int opt) {
    // Reorder into bit-reversed index order. Only indices [0, lim) belong to
    // the transform; i < rev[i] makes sure each pair is swapped exactly once.
    for (int i = 0; i < lim; i++) {
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    }

    // Butterfly passes: `mid` is the half-length of the blocks being merged.
    for (int mid = 1; mid < lim; mid <<= 1) {
        // Primitive (2*mid)-th root of unity; the sign of the angle follows opt.
        const Complex wn(std::cos(PI / mid), opt * std::sin(PI / mid));
        const int block = mid << 1;
        for (int j = 0; j < lim; j += block) {
            Complex w(1.0, 0.0);
            for (int k = 0; k < mid; k++) {
                const Complex x = a[j + k];
                const Complex y = w * a[j + mid + k];
                a[j + k] = x + y;
                a[j + mid + k] = x - y;
                w = w * wn;
            }
        }
    }

    if (opt == -1) {
        // Normalise both components so the inverse is correct for complex
        // output as well, not only for callers that read the real part.
        for (int i = 0; i < lim; i++) {
            a[i].x /= lim;
            a[i].y /= lim;
        }
    }
}
// Reads n, m and two sequences of n integers each, then prints the minimum of
//
//     sum_{i=1..n} (num1[i] + c - num2[(i - 1 + s) mod n])^2
//
// over the integer offset c and the cyclic shift s.
//
// Expanding the square gives
//     sum(num1^2) + sum(num2^2) + n*c^2 + 2*c*ttt - 2 * sum(num1[i] * num2[...])
// with ttt = sum(num1) - sum(num2). The part depending on c is a convex
// quadratic minimised at the integer nearest to -ttt / n; the last sum is a
// cyclic cross-correlation, obtained for all shifts at once as the convolution
// of the reversed num1 with num2 written twice back to back.
int main() {
    n = init();
    m = init(); // read to advance the input; not used below

    if (n <= 0) {
        // Nothing to pair up; also avoids the 0/0 in the offset computation.
        cout << 0 << '\n';
        return 0;
    }

    for (int i = 1; i <= n; i++) {
        num1[i] = init();
        a[n - i].x = num1[i]; // stored reversed so the convolution yields a correlation
        ans += static_cast<long long>(num1[i]) * num1[i];
        ttt += num1[i];
    }
    for (int i = 0; i < n; i++) {
        num2[i] = init();
        b[i].x = b[i + n].x = num2[i]; // doubled to emulate the cyclic shift
        ans += static_cast<long long>(num2[i]) * num2[i];
        ttt -= num2[i];
    }

    // Integer nearest to -ttt / n (halves rounded away from zero).
    const double t = -static_cast<double>(ttt) / n;
    c = (t > 0.0) ? static_cast<long long>(t + 0.5)
                  : static_cast<long long>(t - 0.5);
    ans += n * c * c + 2 * c * ttt;

    // Smallest power of two strictly greater than 3n: the product of a
    // length-n and a length-2n sequence has 3n - 1 coefficients.
    while (lim <= n * 3) lim <<= 1, limcnt++;
    for (int i = 0; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (limcnt - 1));

    fft(a, 1);
    fft(b, 1);
    for (int i = 0; i < lim; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, -1);

    // Coefficient n - 1 + s holds the correlation for shift s. The exact values
    // are non-negative integers, so adding 0.5 before truncating rounds to the
    // nearest integer and absorbs floating-point error in either direction.
    long long best = 0;
    for (int i = n - 1; i <= 2 * n - 1; i++) {
        best = max(best, static_cast<long long>(a[i].x + 0.5));
    }
    ans -= 2 * best;

    cout << ans << '\n';
    return 0;
}