bzoj 4827 [Hnoi2017]礼物 FFT
题面
解法
显然,我们可以列出答案的式子:
\[\sum_{i=1}^n(a_i+c-b_i)^2
\]
然后将式子拆开,可得
\[=\sum_{i=1}^n(a_i^2+b_i^2)+nc^2+2c\sum_{i=1}^n(a_i-b_i)-2\sum_{i=1}^na_ib_i
\]
第一项显然是一个定值,第二,三项是一个关于\(c\)的二次函数,求一个最小值即可
所以,我们现在只要使\(\sum a_ib_i\)尽量大即可
我们令第一个手环不动,第二个手环转动\(k\)个,将\(b\)倍长,那么就变成了\(\sum a_ib_{i+k}\)
感觉这样不是很好求,那么不妨将\(a\)倒序,就变成了\(\sum a_{n-i+1}b_{i+k}\)
那么就变成了一个卷积的形式了
然后FFT即可
时间复杂度:\(O(n\ log\ n)\)
代码
#include <bits/stdc++.h>
#define int long long
#define N 1 << 18
using namespace std;
template <typename node> void chkmax(node &x, node y) {x = max(x, y);}
template <typename node> void chkmin(node &x, node y) {x = min(x, y);}
template <typename node> void read(node &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Complex {
double x, y;
Complex (double tx = 0, double ty = 0) {x = tx, y = ty;}
} a[N], b[N];
const double pi = acos(-1);
int rev[N], x[N], y[N];
Complex operator + (Complex a, Complex b) {return (Complex) {a.x + b.x, a.y + b.y};}
Complex operator - (Complex a, Complex b) {return (Complex) {a.x - b.x, a.y - b.y};}
Complex operator * (Complex a, Complex b) {return (Complex) {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};}
int calc(int a, int b, int x) {
if (x < 0) return LONG_LONG_MAX;
return a * x * x + b * x;
}
int solve(double a, double b) {
int x = (int)floor(-b / a / 2), y = (int)ceil(-b / a / 2);
int ret = LONG_LONG_MAX;
if (x < 0 || y < 0) return 0;
for (int i = -1; i <= 1; i++)
chkmin(ret, min(calc(a, b, x + i), calc(a, b, y + i)));
return ret;
}
void getrev(int l) {
for (int i = 0; i < (1 << l); i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
}
void FFT(Complex *a, int n, int key) {
for (int i = 0; i < n; i++)
if (i > rev[i]) swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1) {
Complex wn(cos(pi / i), key * sin(pi / i));
for (int r = i << 1, j = 0; j < n; j += r) {
Complex w(1, 0);
for (int k = 0; k < i; k++, w = w * wn) {
Complex x = a[j + k], y = w * a[i + j + k];
a[j + k] = x + y, a[i + j + k] = x - y;
}
}
}
if (key == -1)
for (int i = 0; i < n; i++) a[i].x /= n;
}
main() {
int n, m; read(n), read(m);
int ans = 0, tmp = 0;
for (int i = 1; i <= n; i++)
read(x[i]), ans += x[i] * x[i];
for (int i = 1; i <= n; i++)
read(y[i]), ans += y[i] * y[i];
for (int i = 1; i <= n; i++) tmp += x[i] - y[i];
ans += solve(n, 2 * tmp);
reverse(x + 1, x + n + 1);
for (int i = 1; i <= n; i++) y[i + n] = y[i];
for (int i = 1; i <= n; i++) a[i].x = x[i];
for (int i = 1; i <= 2 * n; i++) b[i].x = y[i];
int len = 1, l = 0;
while (len <= 3 * n + 2) len <<= 1, l++; getrev(l);
FFT(a, len, 1), FFT(b, len, 1);
for (int i = 0; i < len; i++) a[i] = a[i] * b[i];
FFT(a, len, -1); tmp = 0;
for (int i = 0; i < len; i++)
chkmax(tmp, (int)(a[i].x + 0.5));
ans -= 2 * tmp; cout << ans << "\n";
return 0;
}