hihoCoder #1388: Periodic Signal
$\min_{0 \le k < n} \sum_{i = 0}^{n-1}\left(A_i - B_{(i+k) \bmod n}\right)^2$
把平方项拆开得 $\sum_{i = 0}^{n-1}A_i^2 + \sum_{i = 0}^{n-1}B_{(i+k) \bmod n}^2 - 2 \sum_{i = 0}^{n-1}A_iB_{(i+k) \bmod n}$，其中前两项都与 $k$ 无关（$B$ 只是被循环移位，平方和不变）。
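即要最大化 $\sum_{i = 0}^{n-1}A_iB_{(i+k) \bmod n}$。这是一个循环相关：把 $A$ 翻转成 $A'_j = A_{n-1-j}$，再把 $B$ 倍长为 $B_0, \dots, B_{n-1}, B_0, \dots, B_{n-1}$，就变成了普通卷积 $C_i = \sum_{j} A'_{j}B_{i-j}$。对 $n-1 \le i \le 2n-2$，有 $C_i = \sum_{t=0}^{n-1} A_t B_{t+i-n+1}$，即第 $i$ 项就是 $k = i-n+1$ 时要求的和。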
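因为这些和可能超过 double 能精确表示的范围（$2^{53}$），FFT 得到的 $C_i$ 带有浮点误差，不能直接拿来算答案；但误差一般不会改变最大值的位置，所以只用 FFT 的结果找出最优的 $k$，再用整数精确地重新计算一遍即可。（严格来说，若两个 $k$ 对应的和非常接近，浮点误差仍可能选错；下面的 NTT 做法在所有和都小于模数时是精确的。）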
// Solution 1: locate the best shift with a floating-point FFT, then
// recompute the answer for that shift exactly with 64-bit integers.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>

typedef long long ll;

// Capacity of every work array. The FFT length is the smallest power of two
// strictly greater than 2n, so 1 << 18 supports n <= 131071 (a non-power-of-two
// capacity such as 2e5 + 7 overflows as soon as the length reaches 262144).
const int N = 1 << 18;

namespace FFT {
const double pi = std::acos(-1.0);

struct Complex {
    double r, i;
    Complex() : r(0.0), i(0.0) {}
    Complex(double r, double i) : r(r), i(i) {}
    Complex operator+(const Complex &p) const { return Complex(r + p.r, i + p.i); }
    Complex operator-(const Complex &p) const { return Complex(r - p.r, i - p.i); }
    Complex operator*(const Complex &p) const {
        return Complex(r * p.r - i * p.i, r * p.i + i * p.r);
    }
} A[N], B[N], root[N];

int n, l, r[N];

// Sets the transform length n to the smallest power of two > m, builds the
// bit-reversal permutation r[], and fills root[mid + k] = exp(i*pi*k/mid) for
// every butterfly half-width mid. Each root is computed directly instead of by
// repeated multiplication, so rounding error does not accumulate along a row.
void init(int m) {
    n = 1, l = 0;
    while (n <= m) n <<= 1, l++;
    for (int i = 0; i < n; i++) r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
    for (int mid = 1; mid < n; mid <<= 1)
        for (int k = 0; k < mid; k++)
            root[mid + k] = Complex(std::cos(pi * k / mid), std::sin(pi * k / mid));
}

// In-place iterative radix-2 FFT of a[0..n).
// pd = 1: forward transform; pd = -1: inverse transform, including the division by n.
void FFT(Complex *a, int pd) {
    for (int i = 0; i < n; i++)
        if (i < r[i]) std::swap(a[i], a[r[i]]);
    for (int mid = 1; mid < n; mid <<= 1)
        for (int j = 0; j < n; j += mid << 1)
            for (int k = 0; k < mid; k++) {
                // The inverse transform uses the conjugate root.
                Complex w(root[mid + k].r, pd * root[mid + k].i);
                Complex u = a[j + k], v = w * a[j + k + mid];
                a[j + k] = u + v;
                a[j + k + mid] = u - v;
            }
    if (pd == -1)
        for (int i = 0; i < n; i++) a[i] = Complex(a[i].r / n, a[i].i / n);
}

// Cyclic convolution over the transform length n chosen by init(m):
// sum[i] = round(sum over j of a[j] * b[(i - j) mod n]) for 0 <= i < n.
// a, b and sum must each have at least n entries (entries of a and b beyond
// their meaningful length must be zero). The values are only approximate
// once their magnitude approaches 2^53.
void solve(ll *a, ll *b, int m, ll *sum) {
    init(m);
    for (int i = 0; i < n; i++) {
        A[i] = Complex(static_cast<double>(a[i]), 0.0);
        B[i] = Complex(static_cast<double>(b[i]), 0.0);
    }
    FFT(A, 1);
    FFT(B, 1);
    for (int i = 0; i < n; i++) A[i] = A[i] * B[i];
    FFT(A, -1);
    // llround rounds to nearest for negative values too, unlike "+ 0.5 then truncate".
    for (int i = 0; i < n; i++) sum[i] = std::llround(A[i].r);
}
}  // namespace FFT

ll A[N], B[N], sum[N];

// Zeroes the first m entries of the global work arrays between test cases.
void clear(int m) {
    for (int i = 0; i < m; i++) A[i] = B[i] = sum[i] = 0;
}

int main() {
    int T;
    if (std::scanf("%d", &T) != 1) return 0;
    while (T--) {
        // NOTE(review): assumes 1 <= n and 2n < N (n <= 131071) — confirm against the problem bounds.
        int n;
        if (std::scanf("%d", &n) != 1) break;
        ll sumA = 0, sumB = 0;
        // A is stored reversed (A[n-1-i] = A_i) so the correlation becomes a convolution.
        for (int i = 0; i < n; i++) {
            std::scanf("%lld", A + n - i - 1);
            sumA += A[n - i - 1] * A[n - i - 1];
        }
        // B is stored twice (B[i] = B[n+i] = B_i) so the cyclic shift becomes a plain offset.
        for (int i = 0; i < n; i++) {
            std::scanf("%lld", B + i);
            sumB += B[i] * B[i];
            B[n + i] = B[i];
        }
        FFT::solve(A, B, 2 * n, sum);
        // sum[i] approximates sum_t A_t * B_{(t + i - n + 1) mod n}, i.e. shift
        // k = i - n + 1. Indices n..2n-1 cover k = 1..n, and k = n equals k = 0.
        // The transform length exceeds 2n, so cyclic wrap-around only reaches
        // indices below n - 2 and does not disturb these entries.
        // NOTE: the argmax is taken on rounded floating-point values; if two
        // shifts have nearly equal sums, rounding could pick the slightly worse one.
        int k = n;
        for (int i = n + 1; i < 2 * n; i++)
            if (sum[k] < sum[i]) k = i;
        k = (k - (n - 1)) % n;
        // Recompute the objective for the chosen shift exactly with integers.
        ll ans = sumA + sumB;
        for (int i = 0; i < n; i++) ans -= 2 * A[n - i - 1] * B[(i + k) % n];
        std::printf("%lld\n", ans);
        clear(2 * n);
    }
    return 0;
}
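取一个超大的模数 $P = 50000000001507329 = 190734863287 \times 2^{18} + 1$ 直接做 NTT（代码中原根取 $G = 3$），只要所有 $C_i < P$，模 $P$ 的结果就是精确值，注意变换长度不能超过 $2^{18}$。由于 $P$ 超过 $2^{32}$，两数相乘会溢出 64 位，所以用 long double 估算商来实现 $O(1)$ 的快速乘。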
// Solution 2: exact convolution with a single large NTT-friendly modulus.
#include <cstdio>
#include <utility>

typedef long long ll;
typedef unsigned long long ull;

// Capacity of every work array. MOD - 1 = 190734863287 * 2^18, so 2^18 is the
// longest transform this modulus supports. The transform length is the smallest
// power of two > 2n, so this supports n <= 131071 (a non-power-of-two capacity
// such as 2e5 + 7 overflows as soon as the length reaches 262144).
const int N = 1 << 18;

namespace NTT {
// 50000000001507329 = 190734863287 * 2^18 + 1; G = 3 is the generator used
// to derive the roots of unity.
const ll MOD = 50000000001507329LL;
const int G = 3;

int n, l, r[N];

// Sets the transform length n to the smallest power of two > m and builds the
// bit-reversal permutation r[].
void init(int m) {
    n = 1, l = 0;
    while (n <= m) n <<= 1, l++;
    for (int i = 0; i < n; i++) r[i] = r[i >> 1] >> 1 | ((i & 1) << (l - 1));
}

// Returns x * y mod MOD for 0 <= x, y < MOD without 128-bit integers.
// The quotient is estimated in long double (off by at most one), and the
// remainder is formed in unsigned arithmetic, whose wrap-around is well defined
// (signed 64-bit overflow would be undefined behavior). Adding MOD keeps the true
// remainder in [0, 3 * MOD) before the final reduction.
// NOTE(review): needs a long double with a 64-bit mantissa (e.g. GCC/Clang on
// x86); where long double == double (e.g. MSVC) the estimate is too coarse.
ll mul(ll x, ll y) {
    ull q = static_cast<ull>(static_cast<long double>(x) * y / MOD);
    ull rem = static_cast<ull>(x) * static_cast<ull>(y) - q * static_cast<ull>(MOD) +
              static_cast<ull>(MOD);
    return static_cast<ll>(rem % static_cast<ull>(MOD));
}

// Returns a^b mod MOD by binary exponentiation (1 when b == 0).
ll qp(ll a, ll b) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans = mul(ans, a);
        a = mul(a, a);
        b >>= 1;
    }
    return ans % MOD;
}

// In-place iterative NTT of a[0..n) modulo MOD; entries must lie in [0, MOD).
// pd = 1: forward transform; pd = -1: inverse transform, including the
// multiplication by n^{-1}.
void NTT(ll a[], int pd) {
    for (int i = 0; i < n; i++)
        if (i < r[i]) std::swap(a[i], a[r[i]]);
    for (int mid = 1; mid < n; mid <<= 1) {
        int len = mid << 1;
        // Primitive len-th root of unity (requires len to divide 2^18).
        ll wn = qp(G, (MOD - 1) / len);
        if (pd == -1) wn = qp(wn, MOD - 2);
        for (int j = 0; j < n; j += len) {
            ll w = 1;
            for (int k = 0; k < mid; k++, w = mul(w, wn)) {
                ll u = a[j + k], v = mul(w, a[j + k + mid]);
                a[j + k] = (u + v) % MOD;
                a[j + k + mid] = (u - v + MOD) % MOD;
            }
        }
    }
    if (pd == -1) {
        ll inv = qp(n, MOD - 2);
        for (int i = 0; i < n; i++) a[i] = mul(a[i], inv);
    }
}

// A holds the reversed first sequence (length m), B the doubled second sequence
// (length 2m); both are zero up to the transform length. Returns the maximum of
// the convolution entries m..2m-1, i.e. the best correlation over all shifts.
// A is overwritten with the convolution and B with its transform.
// The comparison is done on residues mod MOD, so the result is the true maximum
// only if every correlation value is < MOD (about 5e16).
// TODO(review): confirm the problem bounds guarantee this.
ll solve(ll *A, ll *B, int m) {
    init(m * 2);
    NTT(A, 1);
    NTT(B, 1);
    for (int i = 0; i < n; i++) A[i] = mul(A[i], B[i]);
    NTT(A, -1);
    ll ans = A[m];
    for (int i = m; i < m * 2; i++)
        if (ans < A[i]) ans = A[i];
    return ans;
}
}  // namespace NTT

ll A[N], B[N];

// Zeroes the first n entries of the global work arrays. NTT::solve transforms A
// and B in place over the whole transform length, so the caller clears NTT::n
// entries rather than just 2n.
void clear(int n) {
    for (int i = 0; i < n; i++) A[i] = B[i] = 0;
}

int main() {
    int T;
    if (std::scanf("%d", &T) != 1) return 0;
    while (T--) {
        // NOTE(review): assumes 1 <= n <= 131071 and non-negative inputs below MOD — confirm.
        int n;
        if (std::scanf("%d", &n) != 1) break;
        ll sumA = 0, sumB = 0;
        // A is stored reversed (A[n-1-i] = A_i) so the correlation becomes a convolution.
        for (int i = 0; i < n; i++) {
            std::scanf("%lld", A + n - i - 1);
            sumA += A[n - i - 1] * A[n - i - 1];
        }
        // B is stored twice (B[i] = B[n+i] = B_i) so the cyclic shift becomes a plain offset.
        for (int i = 0; i < n; i++) {
            std::scanf("%lld", B + i);
            sumB += B[i] * B[i];
            B[n + i] = B[i];
        }
        ll ans = sumA + sumB - 2 * NTT::solve(A, B, n);
        std::printf("%lld\n", ans);
        clear(NTT::n);
    }
    return 0;
}