AHOI/HNOI2017 礼物
题目链接:戳我
对于题目中给的式子:(大家暂且把\(y_i\)当作\(y_{i+k}\)来看啦qwq)
\(\sum_{i=1}^{n}(x_i-(y_i+c))^2\)
\(=\sum_{i=1}^n x_i-2x_i(y_i+c)+(y_i+c)^2\)
\(=\sum_{i=1}^nx_i^2-2x_iy_i-2x_ic+y_i^2+2y_ic+c^2\)
\(=\sum_{i=1}^{n} x_i^2-\sum_{i=1}^{n} 2x_iy_i-\sum_{i=1}^n2x_ic+\sum_{i=1}^ny_i^2+\sum_{i=1}^n2y_ic+\sum_{i=1}^nc^2\)
现在问题转化成了最大化\(\sum_{i=1}^{n}x_iy_i\)。
然后我们把y反转,我们大概就得到了这样的一个式子:
\(\sum_{i=1}^{n}x_iy_{n-i+1}\)
唔,卷积QAQ
然后设\(f(x)\)和\(g(x)\),\(f(x)\)的第\(i\)项系数是\(x_i\),\(g(x)\)的第\(i\)项系数是\(y_{n-i+1}\),所以\(f(x)∗g(x)\)的第\(n+1\)项系数就是第一个环的第n个和第二个环的第1个重合的结果(当然,也可以是前者的第1个和后者的第n个重合的结果,大家可以手动画图比对一下(注意是逆时针分布哦qwq)因为他们的次数相加等于n+1),然后以此类推,断环为链,所以倍长g,n+1~2n项的最大值即为所求。(注意f是要补位的,不过为了对答案不造成影响,要全部补零)
因为c很小,所以暴力枚举就行啦!
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<complex>
#define MAXN 2000000
#define INF 0x3f3f3f3f
using namespace std;
const double pi=acos(-1.0);
int N,M,l,n,m,ans=INF;
int p1,p2,t1,t2,cur_ans=-INF;
int r[MAXN],s1[MAXN],s2[MAXN],S[MAXN];
complex<double> a[MAXN],b[MAXN];
inline void fft(complex<double> *P,int opt)
{
for(int i=0;i<N;i++)
if(i<r[i])
swap(P[i],P[r[i]]);
for(int i=1;i<N;i<<=1)
{
complex<double> W(cos(pi/i),opt*sin(pi/i));
for(int p=i<<1,j=0;j<N;j+=p)
{
complex<double> w(1,0);
for(int k=0;k<i;k++,w*=W)
{
complex<double> X=P[j+k],Y=w*P[j+k+i];
P[j+k]=X+Y,P[j+k+i]=X-Y;
}
}
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("ce.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&s1[i]);
for(int i=1;i<=n;i++) scanf("%d",&s2[i]);
N=n-1,M=2*n-1;
for(int i=0;i<=N;i++) a[i]=s1[i+1];
for(int i=0;i<n;i++) b[i]=s2[n-i];
for(int i=0;i<n;i++) b[i+n]=b[i];
M+=N;
for(N=1;N<=M;N<<=1) l++;
for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
fft(a,1),fft(b,1);
for(int i=0;i<N;i++) a[i]=a[i]*b[i];
fft(a,-1);
for(int i=0;i<=M;i++) S[i]=(int)(a[i].real()/N+0.5);
for(int i=1;i<=n;i++)
p1+=s1[i]*s1[i],p2+=s2[i]*s2[i],t1+=s1[i],t2+=s2[i];
for(int i=n-1;i<=2*n-1;i++) cur_ans=max(cur_ans,S[i]);
for(int c=-m;c<=m;c++)
{
int ansans=p1+p2+n*c*c+2*c*(t1-t2)-2*cur_ans;
ans=min(ans,ansans);
}
printf("%d\n",ans);
return 0;
}