【xsy1131】tortue FFT
题目大意:
一次游戏要按N个按键。每个按键阿米巴有P[i]的概率按错。对于一串x个连续按对的按键,阿米巴可以得分
$f(x)=tan(\dfrac{x}{N})\times e^{arcsin(0.8\times \frac{x}{N})}\times N$
在阿米巴疯狂的玩这款游戏之前,小强想知道,阿米巴的期望得分是多少。
数据范围:$n≤10^5$
貌似题解是泰勒展开,然而我自己想的做法是分治FFT,最后写了个没有分治的FFT。
我们记$S_i$表示连续按下至少$i$个按键的期望次数。
那么答案显然为$\sum_{i=1}^{n}S_i-2S_{i+1}+S_{i+2}$
考虑如何求$S$,不难发现$S_i=\sum_{j=1}^{n-i+1}\prod_{k=0}^{i-1}(1-P[j+k])$
直接求显然是$O(n^3)$的,通过前缀积优化一发可以做到$O(n^2)$
我们构造序列$F$和序列$G$,令$F_i=\prod_{j=1}^{i}(1-P[j])$,$G_i=\dfrac{1}{\prod_{j=1}^{i-1}(1-P[n-j])}$。
不难发现,$S_i=\sum_{j=1}^{i}F_{j}G_{i-j+1}$
我们用FFT加速一波就可以求了
时间复杂度:$O(n\log\ n)$。
PS:此题卡精度,不要尝试使用两次FFT做卷积,要用三次FFT!!!
1 #include<bits/stdc++.h> 2 #define M (1<<18) 3 #define PI acos(-1) 4 using namespace std; 5 6 struct cp{ 7 double i,r; 8 cp(double R=0,double I=0){i=I; r=R;} 9 friend cp operator +(cp a,cp b){return cp(a.r+b.r,a.i+b.i);} 10 friend cp operator -(cp a,cp b){return cp(a.r-b.r,a.i-b.i);} 11 friend cp operator *(cp a,cp b){return cp(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r);} 12 friend cp operator /(cp a,double b){return cp(a.r/b,a.i/b);} 13 }a[M],b[M]; 14 void change(cp a[],int len){ 15 for(int i=0,j=0;i<len-1;i++){ 16 if(i<j) swap(a[i],a[j]); 17 int k=len>>1; 18 while(j>=k) j-=k,k>>=1; 19 j+=k; 20 } 21 } 22 void FFT(cp a[],int len,int on){ 23 change(a,len); 24 for(int h=2;h<=len;h<<=1){ 25 cp wn=cp(cos(2*PI/h),sin(2*PI/h*on)); 26 for(int j=0;j<len;j+=h){ 27 cp w=cp(1,0); 28 for(int k=j;k<j+(h>>1);k++){ 29 cp u=a[k],t=w*a[k+(h>>1)]; 30 a[k]=u+t; a[k+(h>>1)]=u-t; 31 w=w*wn; 32 } 33 } 34 } 35 if(on==-1){ 36 for(int i=0;i<len;i++) 37 a[i]=a[i]/len; 38 } 39 } 40 41 double p[M]={0},s[M]={0},ans=0;int n; 42 double f(double x){return tan(x/n)*exp(asin(0.8*x/n))*n;} 43 44 int main(){ 45 scanf("%d",&n); 46 for(int i=0;i<n;i++) scanf("%lf",p+i),p[i]=1-p[i]; 47 a[0].r=1; for(int i=0;i<n-1;i++) a[i+1].r=a[i].r/p[i]; 48 b[n-1].r=p[0]; for(int i=n-2;~i;i--) b[i].r=b[i+1].r*p[n-i-1]; 49 int m=1; for(;m<n*2;m<<=1); 50 FFT(a,m,1); FFT(b,m,1); 51 for(int i=0;i<m;i++) a[i]=a[i]*b[i]; 52 FFT(a,m,-1); 53 for(int i=1;i<=n;i++) s[i]=a[n-i].r; 54 for(int i=1;i<=n;i++) ans+=(s[i]-2*s[i+1]+s[i+2])*f(i); 55 printf("%.10lf\n",ans); 56 }