BZOJ 4827: [Hnoi2017]礼物
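即求一个 \(k\) 和整数 \(c\),使得 \(\sum\limits_{i=0}^{n-1}(x_i+c-y_{(i+k)\bmod n})^2\) 最小(下标从 \(0\) 开始)
把式子拆开,去掉与 \(c\)、\(k\) 都无关的常数 \(\sum x_i^2+\sum y_i^2\),并注意到 \(\sum\limits_{i=0}^{n-1} y_{(i+k)\bmod n}=\sum y_i\) 与 \(k\) 无关,得到 \(nc^2+2c(\sum x_i-\sum y_i)-2\sum\limits_{i=0}^{n-1}x_iy_{(i+k)\bmod n}\)
前两项是关于 \(c\) 的开口向上的二次函数,且与 \(k\) 无关,所以 \(c\) 和 \(k\) 可以分开求:当 \(c=\dfrac{\sum y_i-\sum x_i}{n}\) 时取值最小;由于 \(c\) 必须是整数,取它向下取整和向上取整中较优的一个即可
然后把 \(c\) 加到 \(x\) 上,现在就变成了求 \(\max\limits_{k}\{\sum\limits_{i=0}^{n-1} x_iy_{(i+k)\bmod n}\}\)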
这是一个循环卷积的形式,处理方式为将 \(x\) 翻转,将 \(y\) 倍长
FFT 之后,结果中位置 \(i\)(\(n - 1 \leq i \leq 2n-2\))上的值就是平移量 \(k=i-(n-1)\) 对应的 \(\sum x_iy_{(i+k)\bmod n}\)
这题值域较小,卷积结果最大约为 \(nm^2\) 量级(不超过 \(10^9\)),在 double 精度内,FFT 的结果四舍五入后是准确的
如果值域较大,可以先用 FFT 求出最优的 \(k\),再用整数运算重新循环计算一遍 \(\sum x_iy_{(i+k)\bmod n}\) 得到准确答案
#include <bits/stdc++.h>
// Capacity of every work array: enough for the FFT length and the doubled input.
const int N = 2e5 + 7;
// Iterative radix-2 FFT used to multiply integer sequences.
namespace FFT {
const double pi = acos(-1.0);

// Minimal complex number type used by the transform.
struct Complex {
    double r, i;
    Complex(): r(0.0), i(0.0) {}
    Complex(double r, double i): r(r), i(i) {}
    Complex operator + (const Complex &p) const { return Complex(r + p.r, i + p.i); }
    Complex operator - (const Complex &p) const { return Complex(r - p.r, i - p.i); }
    Complex operator * (const Complex &p) const { return Complex(r * p.r - i * p.i, r * p.i + i * p.r); }
} A[N], B[N], W[N];  // A, B: work buffers for solve; W: twiddles of the current level (n / 2 used)

// n: transform length (a power of two), l = log2(n), r: bit-reversal permutation of 0 .. n-1.
int n, l, r[N];

// Pick the smallest power of two n > m and build the bit-reversal table for it.
void init(int m) {
    n = 1, l = 0;
    while (n <= m) n <<= 1, l++;
    for (int i = 0; i < n; i++)
        // The conditional keeps us from evaluating a shift by l - 1 == -1
        // (undefined behaviour) when n == 1.
        r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (1 << (l - 1)) : 0);
}

// In-place transform of a[0 .. n-1] using the table built by init.
// pd = 1: forward transform; pd = -1: inverse transform, scaled by 1 / n.
void FFT(Complex *a, int pd) {
    for (int i = 0; i < n; i++)
        if (i < r[i])
            std::swap(a[i], a[r[i]]);
    for (int mid = 1; mid < n; mid <<= 1) {
        // Evaluate each twiddle directly instead of by repeated multiplication,
        // so rounding error does not build up along long levels.
        for (int k = 0; k < mid; k++)
            W[k] = Complex(cos(pi * k / mid), pd * sin(pi * k / mid));
        for (int len = mid << 1, j = 0; j < n; j += len)
            for (int k = 0; k < mid; k++) {
                Complex u = a[j + k], v = W[k] * a[j + k + mid];
                a[j + k] = u + v;
                a[j + k + mid] = u - v;
            }
    }
    if (pd == -1)
        for (int i = 0; i < n; i++)
            a[i] = Complex(a[i].r / n, a[i].i / n);
}

// Multiply the integer sequences a and b and store the rounded coefficients in sum.
// The transform length is the smallest power of two L > m, and the product is
// computed cyclically modulo L: a coefficient whose true index is t >= L is added
// onto index t - L. a[i] and b[i] are read for every i < L and sum[i] is written for
// every i < L, so all three arrays need at least L entries, and the entries past
// the meaningful data must be zero.
void solve(int *a, int *b, int m, int *sum) {
    init(m);
    for (int i = 0; i < n; i++)
        A[i] = Complex((double)a[i], 0), B[i] = Complex((double)b[i], 0);
    FFT(A, 1); FFT(B, 1);
    for (int i = 0; i < n; i++)
        A[i] = A[i] * B[i];
    FFT(A, -1);
    for (int i = 0; i < n; i++)
        // Round to nearest for negative results too: (int)(x + 0.5) would truncate
        // toward zero and be off by one for every negative coefficient.
        sum[i] = static_cast<int>(std::lround(A[i].r));
}
}
int a[N], B[N], sum[N], A[N];
int main() {
int n, m;
scanf("%d%d", &n, &m);
int sumA1 = 0, sumB1 = 0, sumA2 = 0, sumB2 = 0;
for (int i = 0; i < n; i++)
scanf("%d", a + n - i - 1), sumA1 += a[n - i - 1];
for (int i = 0; i < n; i++)
scanf("%d", B + i), sumB2 += B[i] * B[i], sumB1 += B[i], B[n + i] = B[i];
int mn = (sumB1 - sumA1) / n;
int c = mn - 1;
for (int pos = mn - 1; pos <= mn + 1; pos++)
if (n * pos * pos + 2 * pos * (sumA1 - sumB1) < c * c * n + 2 * c * (sumA1 - sumB1))
c = pos;
for (int i = 0; i < n; i++)
A[i] = a[i] + c, sumA2 += A[i] * A[i];
FFT::solve(A, B, 2 * n, sum);
int mx = sum[n - 1];
for (int i = n; i <= 2 * n - 2; i++)
if (mx < sum[i]) mx = sum[i];
int ans = sumA2 + sumB2 - 2 * mx;
printf("%d\n", ans);
return 0;
}