Luogu 3723 [AH2017/HNOI2017]礼物
BZOJ 4827
$$\sum_{i = 1}^{n}(x_i - y_i + c)^2 = \sum_{i = 1}^{n}(x_i^2 + y_i^2 + c^2 - 2 * x_iy_i + 2c * x_i - 2c * y_i) = \sum_{i = 1}^{n}x_i^2 + \sum_{i = 1}^{n}y_i^2 + nc^2 + (2\sum_{i = 1}^{n}(x_i -y_i))c - 2 * \sum_{i = 1}^{n}x_iy_i$$
发现第一项和第二项是一个定值,而第三项和第四项构成了一个开口向上的二次函数,当$c$最靠近对称轴的时候最小,唯一要处理的是最后一项$\sum_{i = 1}^{n}x_iy_i$。
把$y$序列翻转一下,变成$\sum_{i = 1}^{n}x_iy_{n - i + 1}$,这是一个卷积的形式,可以使用$NTT$来加速。
而题目中要求可以旋转一个序列,那么把$x$序列倍长然后当作多项式和翻转后的$y$序列乘起来。
这一项就是在乘起来之后的第$n + 1$项到$2 * n$项中取个最大值。
时间复杂度$O(nlogn)$。
Code:
#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 3e5 + 5; const ll P = 998244353LL; const ll inf = 1LL << 60; int n, m, lim = 1, pos[N]; ll a[N], b[N]; template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') op = -1; for (; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } template <typename T> inline void swap(T &x, T &y) { T t = x; x = y; y = t; } template <typename T> inline void chkMin(T &x, T y) { if(y < x) x = y; } inline ll fpow(ll x, ll y) { ll res = 1LL; for (; y > 0; y >>= 1) { if (y & 1) res = res * x % P; x = x * x % P; } return res; } inline void prework() { int l = 0; for (; lim <= 3 * n; ++l, lim <<= 1); for (int i = 0; i < lim; i++) pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1)); } inline void ntt(ll *c, int opt) { for (int i = 0; i < lim; i++) if (i < pos[i]) swap(c[i], c[pos[i]]); for (int i = 1; i < lim; i <<= 1) { ll wn = fpow(3, (P - 1) / (i << 1)); if(opt == -1) wn = fpow(wn, P - 2); for (int len = i << 1, j = 0; j < lim; j += len) { ll w = 1LL; for (int k = 0; k < i; k++, w = w * wn % P) { ll x = c[j + k], y = w * c[j + k + i] % P; c[j + k] = (x + y) % P, c[j + k + i] = (x - y + P) % P; } } } if (opt == -1) { ll inv = fpow(lim, P - 2); for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P; } } int main() { read(n), read(m); ll suma = 0LL, sqra = 0LL, sumb = 0LL, sqrb = 0LL; for (int i = 0; i < n; i++) { read(a[i]); suma += a[i], sqra += a[i] * a[i]; } for (int i = 0; i < n; i++) { read(b[i]); sumb += b[i], sqrb += b[i] * b[i]; } for (int i = 0; i < n; i++) a[i + n] = a[i]; for (int i = 0; i < (n / 2); i++) swap(b[i], b[n - i - 1]); prework(); ntt(a, 1), ntt(b, 1); for (int i = 0; i < lim; i++) a[i] = a[i] * b[i] % P; ntt(a, -1); /* for (int i = 0; i < lim; i++) printf("%lld%c", a[i], i == (lim - 1) ? '\n' : ' '); */ ll ans = inf; for (int i = 0; i < n; i++) { for (int j = -m; j <= m; j++) { ll res = sqra + sqrb + 1LL * n * j * j; res += 2LL * j * (suma - sumb) - 2LL * a[i + n - 1]; chkMin(ans, res); } } printf("%lld\n", ans); return 0; }