「luogu2414」[AH2017/HNOI2017]礼物
考虑c固定时快速算出所有位置的最小差异值,
把平方拆掉后构造一下发现是个卷积形式,fft即可。
m的范围很小,且答案关于m是一个单峰函数,所以我一开始以为是三分m,算了算复杂度好像很可过的样子,
写出来后发现不开O2只有70分(BZOJ可过)
看了别人的程序才知道c可以直接算出来!!!
其实只要让两个环上的总和最接近,答案就是最小的。
1 #include<bits/stdc++.h> 2 #define R register 3 #define db double 4 using namespace std; 5 const int N=50010,oo=INT_MAX; 6 const db PI=acos(-1); 7 int n,nn,m,x[N<<2],y[N<<2],rev[N<<2],maxp,s; 8 int read(){ 9 int x=0,w=1;char c=0; 10 while(c<'0'||c>'9') c=getchar(); 11 while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar(); 12 return x*w; 13 } 14 struct Com{ 15 db real,image; 16 Com operator+(const Com& k)const{return (Com){real+k.real,image+k.image};} 17 Com operator-(const Com& k)const{return (Com){real-k.real,image-k.image};} 18 Com operator*(const Com& k)const{return (Com){real*k.real-image*k.image,real*k.image+image*k.real};} 19 Com operator/(const int& k)const{return (Com){real/k,image/k};} 20 }fa[N<<2],fb[N<<2],fc[N<<2]; 21 inline void fft(Com* a,int b){ 22 for(R int i=0;i<s;i++) if(rev[i]>i) swap(a[rev[i]],a[i]); 23 for(R int len=2;len<=s;len<<=1){ 24 Com wn=(Com){cos(2.0*PI/len),b*sin(2.0*PI/len)}; 25 for(R int i=0;i<s;i+=len){ 26 Com w=(Com){1,0}; 27 for(R int j=0;j<(len>>1);j++,w=w*wn){ 28 Com a0=a[i+j],a1=w*a[i+j+(len>>1)]; 29 a[i+j]=a0+a1,a[i+j+(len>>1)]=a0-a1; 30 } 31 } 32 } 33 if(b==-1) for(R int i=0;i<s;i++) a[i]=a[i]/s; 34 return; 35 } 36 inline int calc(int k){ 37 int temp=0,ans=oo; 38 for(int j=0;j<s;j++){ 39 fa[j].real=j<n?x[j]+k:0; 40 fb[j].real=j<nn?y[j]:0; 41 fa[j].image=fb[j].image=0; 42 if(j<n) temp+=(x[j]+k)*(x[j]+k)+y[j]*y[j]; 43 } 44 fft(fa,1);fft(fb,1); 45 for(int j=0;j<s;j++) fc[j]=fa[j]*fb[j]; 46 fft(fc,-1); 47 for(int j=n-1;j<nn;j++) ans=min(ans,temp-int(fc[j].real+0.1)*2); 48 return ans; 49 } 50 int main(){ 51 n=read(),m=read(); 52 int sum=0,add; 53 for(int i=0;i<n;i++) x[i]=read(),sum-=x[i]; 54 for(int i=n-1;~i;i--) y[i]=read(),sum+=y[i]; 55 add=(floor)(1.0*sum/n+0.5); 56 nn=(n<<1)-1,s=1; 57 for(int i=n;i<nn;i++) y[i]=y[i-n]; 58 while(s<nn) s<<=1,maxp++; 59 for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)^(((i&1)<<(maxp-1))); 60 for(int i=0;i<n;i++) fa[i].real=x[i]-m-1; 61 printf("%d",calc(add)); 62 return 0; 63 }
三分代码:
1 #include<bits/stdc++.h> 2 #define R register 3 #define db double 4 using namespace std; 5 const int N=50010,oo=INT_MAX; 6 const db PI=acos(-1); 7 int n,nn,m,x[N<<2],y[N<<2],rev[N<<2],maxp,s; 8 int read(){ 9 int x=0,w=1;char c=0; 10 while(c<'0'||c>'9') c=getchar(); 11 while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar(); 12 return x*w; 13 } 14 struct Com{ 15 db real,image; 16 Com operator+(const Com& k)const{return (Com){real+k.real,image+k.image};} 17 Com operator-(const Com& k)const{return (Com){real-k.real,image-k.image};} 18 Com operator*(const Com& k)const{return (Com){real*k.real-image*k.image,real*k.image+image*k.real};} 19 Com operator/(const int& k)const{return (Com){real/k,image/k};} 20 }fa[N<<2],fb[N<<2],fc[N<<2]; 21 void fft(Com* a,int b){ 22 for(R int i=0;i<s;i++) if(rev[i]>i) swap(a[rev[i]],a[i]); 23 for(R int len=2;len<=s;len<<=1){ 24 Com wn=(Com){cos(2.0*PI/len),b*sin(2.0*PI/len)}; 25 for(R int i=0;i<s;i+=len){ 26 Com w=(Com){1,0}; 27 for(R int j=0;j<(len>>1);j++,w=w*wn){ 28 Com a0=a[i+j],a1=w*a[i+j+(len>>1)]; 29 a[i+j]=a0+a1,a[i+j+(len>>1)]=a0-a1; 30 } 31 } 32 } 33 if(b==-1) for(R int i=0;i<s;i++) a[i]=a[i]/s; 34 return; 35 } 36 int calc(int k){ 37 int temp=0,ans=oo; 38 for(int j=0;j<s;j++){ 39 fa[j].real=j<n?x[j]+k:0; 40 fb[j].real=j<nn?y[j]:0; 41 fa[j].image=fb[j].image=0; 42 if(j<n) temp+=(x[j]+k)*(x[j]+k)+y[j]*y[j]; 43 } 44 fft(fa,1);fft(fb,1); 45 for(int j=0;j<s;j++) fc[j]=fa[j]*fb[j]; 46 fft(fc,-1); 47 for(int j=n-1;j<nn;j++) ans=min(ans,temp-int(fc[j].real+0.1)*2); 48 return ans; 49 } 50 int main(){ 51 n=read(),m=read(); 52 for(int i=0;i<n;i++) x[i]=read(); 53 for(int i=n-1;~i;i--) y[i]=read(); 54 nn=(n<<1)-1,s=1; 55 for(int i=n;i<nn;i++) y[i]=y[i-n]; 56 while(s<nn) s<<=1,maxp++; 57 for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)^(((i&1)<<(maxp-1))); 58 for(int i=0;i<n;i++) fa[i].real=x[i]-m-1; 59 int ans=oo; 60 int l=-m,r=m,mid1,mid2,res1,res2; 61 while(l<=r){ 62 mid1=l+(r-l+1)/3,mid2=l+(r-l+1)*2/3; 63 res1=calc(mid1),res2=calc(mid2); 64 ans=min(ans,min(res1,res2)); 65 if(res1<res2) r=mid2-1; 66 else l=mid1+1; 67 } 68 printf("%d",ans); 69 return 0; 70 }