Luogu3723 [AH2017/HNOI2017]礼物
给其中一个手环的所有装饰物都加上非负整数 \(c\)，等价于给另一个手环的所有装饰物都加上 \(-c\)。因此题意可以转化为：对 \(a\) 中所有元素加上同一个任意整数（可正可负）。
设加上的值为 \(x\)，则总差异为
\[\sum_{i=1}^n (a_i+x-b_i)^2\\
=\sum_{i=1}^n \left(a_i^2+b_i^2+x^2+2a_i x-2b_i x-2a_ib_i\right)\\
=\sum_{i=1}^n a_i^2+\sum_{i=1}^n b_i^2+nx^2+2x\sum_{i=1}^n (a_i-b_i)-2\sum_{i=1}^n a_ib_i
\]
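其中 \(\sum_{i=1}^n a_i^2+\sum_{i=1}^n b_i^2\) 是定值；\(nx^2+2x\sum_{i=1}^n (a_i-b_i)\) 是关于 \(x\) 的二次函数，且与旋转无关，只需在其顶点 \(x=-\frac{1}{n}\sum_{i=1}^n (a_i-b_i)\) 附近取整数最小值。
剩下的 \(-2\sum_{i=1}^n a_ib_i\) 要求 \(\sum_{i=1}^n a_ib_i\) 尽量大。设 \(b^r\) 为 \(b\) 翻转后的数组，即 \(b^r_j=b_{n-j+1}\)，则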
\[\sum_{i=1}^n a_ib_i=\sum_{i=1}^n a_i b^r_{n-i+1}
\]
这正是卷积的形式，用 FFT 做一次卷积即可。
但是手环可以旋转，需要考虑所有循环移位，所以把 \(b\) 倍长（令 \(b_{i+n}=b_i\)），再把这个长度为 \(2n\) 的数组翻转，此时 \(b^r_j=b_{2n-j+1}\)。
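也就是求：
\[\max_{0\le k<n} \sum_{i=1}^n a_i b_{i+k}=\max_{0\le k<n} \sum_{i=1}^n a_i b^r_{2n-i-k+1}
\]
右边的和恰好是 \(a\) 与 \(b^r\) 卷积结果的第 \(2n-k+1\) 项，所以只需在卷积结果的第 \(n+2\) 到 \(2n+1\) 项中取最大值。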
FFT 代码：
#include<algorithm>
#include<climits>
#include<cmath>
#include<cstdio>
#include<iostream>
#define N 400005
#define D double
#define ll long long
#define getchar() (*p1++)
using namespace std;
const D pi=acos(-1.0);
char buf[1 << 23],obuf[1 << 23],*p1=buf,*O=obuf;
int n,m,s,l,a[N],b[N];
int rev[N];
ll ans=0;
struct virt
{
D x,y;
virt (D a=0.0,D b=0.0)
{
x=a,y=b;
}
virt operator + (virt a)
{
return virt(x+a.x,y+a.y);
}
virt operator - (virt a)
{
return virt(x-a.x,y-a.y);
}
virt operator * (virt a)
{
return virt(x*a.x-y*a.y,x*a.y+y*a.x);
}
}c[N],d[N];
// Reads the next integer from the input into x.
// Characters before the number are skipped. A '-' immediately preceding the first
// digit makes the result negative. If the input is exhausted before any digit
// appears (a zero byte or EOF), x is left as 0. Without this check, the skip loop
// would keep consuming past the end of the data.
template<typename T>
void read(T &x)
{
	x=0;
	bool neg=false;
	int ch=getchar();
	while (ch<'0' || ch>'9')
	{
		if (ch=='\0' || ch==EOF)
			return;
		neg=(ch=='-');
		ch=getchar();
	}
	while ('0'<=ch && ch<='9')
	{
		x=x*10+(ch-'0');
		ch=getchar();
	}
	if (neg)
		x=-x;
}
// Appends the decimal representation of x to the output buffer at O
// (prefixed with '-' when x is negative).
template<typename T>
void write(T x)
{
	if (x<0)
	{
		*O++='-';
		x=-x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[40];
	int len=0;
	do
	{
		digits[len++]=static_cast<char>(x%10+'0');
		x/=10;
	} while (x);
	while (len>0)
		*O++=digits[--len];
}
// In-place iterative radix-2 FFT over a[0..s-1] using the global rev[] permutation.
// t = 1.0 gives the forward transform; t = -1.0 gives the inverse transform,
// which is left unscaled (the caller divides by s).
void FFT(virt *a,D t)
{
	// Reorder elements into bit-reversed index order.
	for (int i=0;i<s;++i)
	{
		const int j=rev[i];
		if (i<j)
			swap(a[i],a[j]);
	}
	// Merge blocks of size half into blocks of size 2*half.
	for (int half=1;half<s;half*=2)
	{
		const D ang=pi/half;
		const virt step(cos(ang),t*sin(ang));
		const int len=half*2;
		for (int start=0;start<s;start+=len)
		{
			virt w(1.0,0.0);
			for (int k=0;k<half;++k)
			{
				virt u=a[start+k];
				virt v=w*a[start+k+half];
				a[start+k]=u+v;
				a[start+k+half]=u-v;
				w=w*step;
			}
		}
	}
}
int main()
{
fread(buf,1,1 << 21,stdin);
read(n),read(m);
for (int i=1;i<=n;++i)
read(a[i]);
for (int i=1;i<=n;++i)
read(b[i]);
ll rs=0;
for (int i=1;i<=n;++i)
ans+=a[i]*a[i]+b[i]*b[i],rs+=a[i]-b[i];
rs <<=1;
ll x=-rs/n/2,y=x-1,z=x+1;
ans+=min((ll)n*x*x+rs*x,min((ll)n*y*y+rs*y,(ll)n*z*z+rs*z));
for (int i=1;i<=n;++i)
b[i+n]=b[i];
reverse(b+1,b+(n << 1)+1);
int rn=n+1,rm=n+n+1;
for (int i=0;i<rn;++i)
c[i]=virt(1.0*a[i],0.0);
for (int i=0;i<rm;++i)
d[i]=virt(1.0*b[i],0.0);
s=1,l=0;
while (s<rn+rm)
s <<=1,++l;
for (int i=1;i<s;++i)
rev[i]=(rev[i >> 1] >> 1) | ((i & 1) << l-1);
FFT(c,1.0),FFT(d,1.0);
for (int i=0;i<s;++i)
c[i]=c[i]*d[i];
FFT(c,-1.0);
ll mx=-1919191919191919;
for (int i=n+2;i<=n+n+1;++i)
{
ll t=(ll)(c[i].x/s+0.5);
mx=max(mx,t);
}
ans-=mx << 1;
write(ans),*O++='\n';
fwrite(obuf,O-obuf,1,stdout);
return 0;
}