洛谷P3273/LOJ2020/BZOJ4827[AHOI2017/HNOI2017]礼物(FFT)
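因为给$a$整体加$x$、给$b$整体加$y$,与只给$a$整体加$x-y$、$b$保持不变是等价的(只有差$a_i-b_i$影响答案),所以只需讨论给每个$a_i$加上同一个整数$x$的情况。又因为所有亮度都在$[1,m]$内,下面二次函数的对称轴$x=\frac{\sum b_i-\sum a_i}{n}$落在$(-m,m)$内,所以只需在$[-m,m]$中枚举$x$即可。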
对于题目中的式子:
$$\begin{aligned}\sum\limits^n_{i=1}(a_i+x-b_i)^2&=\sum\limits^n_{i=1}(a_i^2+b_i^2+x^2+2a_ix-2b_ix-2a_ib_i)\\&=\sum\limits^n_{i=1}a_i^2+\sum\limits^n_{i=1}b_i^2+nx^2+2x\left(\sum\limits^n_{i=1}a_i-\sum\limits^n_{i=1}b_i\right)-2\sum\limits^n_{i=1}a_ib_i\end{aligned}$$
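可以发现,除$-2\sum\limits^n_{i=1}a_ib_i$以外,其余各项只与$x$有关,与手环的旋转方式无关;而$\sum\limits^n_{i=1}a_ib_i$只与旋转方式有关,与$x$无关。因此两部分可以分开处理:先在所有旋转方式中求出$\sum\limits^n_{i=1}a_ib_i$的最大值,再单独枚举$x$。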
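把$a$翻转,记$a'_i=a_{n-i+1}$,则$\sum\limits^n_{i=1}a_ib_i=\sum\limits^n_{i=1}a'_{n-i+1}b_i$。两个下标之和恒为$n+1$,这是一个典型的卷积式。为了覆盖所有旋转,把$a'$复制一遍接在后面(倍长为$2n$项),与$b$用FFT求卷积,结果中第$n+1$到第$2n$项的系数恰好对应$n$种旋转下的$\sum a_ib_i$。取其中的最大值,再在$[-m,m]$内暴力枚举$x$,求出最小值即可。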
// Luogu P3273 / LOJ2020 / BZOJ4827 [AHOI2017/HNOI2017] Gift (FFT).
//
// Input: n, m, then two sequences x[1..n] and y[1..n].
// Output: min over cyclic shifts k and integers c in [-m, m] of
//   sum_i (x_{shifted by k} + c - y_i)^2.
// Expanding the square, only the cross term sum_i x_{..} * y_i depends on the
// shift, so its maximum over all n shifts comes from one FFT convolution; the
// remaining terms depend only on c and are minimized by brute force.
#include<cstdio>
#include<cmath>

typedef long long ll;

const double pi=acos(-1.0);
// Capacity of the FFT buffers. Must be >= sz, the smallest power of two
// strictly greater than 3n (262144 for n = 50000), and > 2n for a[].
const int N=300000;

// Buffered stdin reader state.
char rB[1<<21],*rS,*rT;

// Returns the next input byte (0..255), or EOF once stdin is exhausted.
inline int gc(){
    if(rS==rT){
        rT=(rS=rB)+fread(rB,1,sizeof rB,stdin);
        if(rS==rT)return EOF;
    }
    return (unsigned char)*rS++;
}

// Reads a non-negative decimal integer, skipping any non-digit prefix.
// Returns 0 if stdin ends before a digit is seen (instead of spinning on EOF).
inline int rd(){
    int c=gc();
    while(c!=EOF&&(c<'0'||c>'9'))c=gc();
    if(c==EOF)return 0;
    int x=0;
    for(;c>='0'&&c<='9';c=gc())x=x*10+(c-'0');
    return x;
}

int r[N];      // bit-reversal permutation for the current transform length
int sz=1,l=0;  // transform length (power of two) and log2(sz)

// Minimal complex number type for the FFT.
struct cnum{
    double x,y;
    cnum(){}
    cnum(double x,double y):x(x),y(y){}
    inline cnum operator +(const cnum &b)const{return cnum(x+b.x,y+b.y);}
    inline cnum operator -(const cnum &b)const{return cnum(x-b.x,y-b.y);}
    inline cnum operator *(const cnum &b)const{return cnum(x*b.x-y*b.y,x*b.y+y*b.x);}
}a[N],b[N],wt[N];  // wt: per-level twiddle factors used inside FFT

inline ll Min(ll a,ll b){return a<b?a:b;}

// In-place iterative radix-2 FFT on a[0..sz). type = 1 is forward, -1 is
// inverse (unscaled: the caller divides by sz). Requires r[] built for sz.
inline void FFT(cnum *a,short type){
    for(int i=0;i<sz;++i)
        if(i<r[i]){cnum t=a[i];a[i]=a[r[i]];a[r[i]]=t;}
    for(int M=1;M<sz;M<<=1){
        // Compute each twiddle directly rather than by repeated multiplication
        // (w = w * wn), so rounding error does not accumulate along a level.
        for(int j=0;j<M;++j)wt[j]=cnum(cos(pi*j/M),type*sin(pi*j/M));
        for(int i=0;i<sz;i+=M<<1)
            for(int j=0;j<M;++j){
                cnum t=a[i+j],tt=a[i+M+j]*wt[j];
                a[i+j]=t+tt;
                a[i+M+j]=t-tt;
            }
    }
}

int main(){
    int n=rd(),m=rd();
    // With no elements every candidate sum is 0; also avoids shifting by
    // l-1 = -1 when building r[] below.
    if(n<=0){printf("0");return 0;}
    ll sa2=0,sa=0,sb2=0,sb=0;
    for(int i=1;i<=n;++i){
        int x=rd();
        sa+=x;sa2+=(ll)x*x;
        // Store x reversed, twice (indices 1..2n), so that coefficients
        // n+1..2n of the product enumerate all n cyclic shifts.
        a[(n<<1)-i+1]=a[n-i+1]=cnum(x,0);
    }
    for(int i=1;i<=n;++i){
        int y=rd();
        sb+=y;sb2+=(ll)y*y;
        b[i]=cnum(y,0);
    }
    // The product has degree up to 3n, so the transform needs sz > 3n.
    for(;sz<=n*3;sz<<=1)++l;
    for(int i=0;i<sz;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    FFT(a,1);FFT(b,1);
    for(int i=0;i<sz;++i)a[i]=a[i]*b[i];
    FFT(a,-1);
    // Only coefficients n+1..2n are needed; read them directly instead of
    // rounding all sz coefficients into a buffer sized for n.
    ll res=0;
    for(int i=n+1;i<=(n<<1);++i){
        ll v=(ll)(a[i].x/sz+0.5);  // values are non-negative: round to nearest
        if(v>res)res=v;
    }
    ll ans=0x3f3f3f3f3f3f3f3fLL;
    // Use multiplication, not (c<<1): left-shifting a negative value is UB.
    for(ll c=-m;c<=m;++c)
        ans=Min(ans,sa2+sb2+n*c*c+2*c*(sa-sb)-2*res);
    printf("%lld",ans);
    return 0;
}