bzoj4332[JSOI2012]分零食
一下午被这题的精度续掉了...首先可以找出一个多项式的等比数列的形式,然后类似poj的Matrix Series,不断倍增就可以了.用复数点值表示进行多次的多项式运算会刷刷地炸精度...应当用int存多项式,然后卷积的时候再dft成复数,卷积之后idft回实数.注意两个m次的多项式卷积之后会变成2m次的多项式,多项式的后一半需要清零.
#include<cstdio> #include<cstring> #include<cmath> #include<algorithm> using namespace std; const int maxn=1024*32*4; #define double long double const double pi=acos(-1); struct comp{ double x,y; comp(){} comp(double a,double b){x=a;y=b;} comp operator +(const comp &a){return comp(x+a.x,y+a.y);} comp operator -(const comp &a){return comp(x-a.x,y-a.y);} comp operator *(const comp &a){return comp(x*a.x-y*a.y,x*a.y+y*a.x);} } ;//a:存储原始多项式 b:存储原始多项式的卷积 c:存储答案 d:存储原多项式的n/2次方 int mod; int a[maxn],c[maxn],d[maxn],e[maxn]; void fft(comp* a,int n,int sign){ for(int i=1,j=0,k=n;i<n;++i,k=n){ do j^=(k>>=1);while(j<k);if(i<j)swap(a[i],a[j]); } for(int j=2;j<=n;j<<=1){ int m=j>>1;comp wn(cos(pi*2/j),sign*sin(pi*2/j)); for(comp *p=a;p!=a+n;p=p+j){ comp w(1,0); for(int k=0;k<m;++k,w=w*wn){ comp t=p[m+k]*w;p[m+k]=p[k]-t;p[k]=p[k]+t; } } } if(sign==-1){ for(int i=0;i<n;++i)a[i].x/=n; } } int N=1;int m; int mo(double x){ return (((int)floor(x+0.5))%mod+mod)%mod; } void mult(int *a,int *b,int *res){ static comp tmp1[maxn],tmp2[maxn]; for(int i=0;i<N;++i)tmp1[i]=comp(a[i],0),tmp2[i]=comp(b[i],0); fft(tmp1,N,1);fft(tmp2,N,1); for(int i=0;i<N;++i)tmp1[i]=tmp1[i]*tmp2[i]; fft(tmp1,N,-1); for(int i=0;i<N;++i)res[i]=mo(tmp1[i].x); } void qsum(int n){ if(n==1){ for(int i=0;i<N;++i)c[i]=a[i]; for(int i=0;i<N;++i)d[i]=a[i]; }else{ qsum(n>>1); mult(c,d,e); //for(int i=0;i<N;++i) //e[i]=c[i]*d[i]+c[i]; for(int i=0;i<N;++i)c[i]=mo(c[i]+e[i]); memset(c+(N>>1),0,sizeof(comp)*(N>>1)); if(n&1){ mult(c,a,e); for(int i=0;i<N;++i)c[i]=mo(a[i]+e[i]); memset(c+(N>>1),0,sizeof(comp)*(N>>1)); } mult(d,d,d); memset(d+(N>>1),0,sizeof(comp)*(N>>1)); if(n&1){ mult(d,a,d); memset(d+(N>>1),0,sizeof(comp)*(N>>1)); } } } int main(){ scanf("%d%d",&m,&mod); int n,o,s,u;scanf("%d%d%d%d",&n,&o,&s,&u); n=min(n,m); for(int i=1;i<=m;++i){ int t=i%mod; a[i]=(o*t*t+s*t+u)%mod; } while(N<=m)N<<=1;N<<=1; qsum(n);printf("%d\n",c[m]); return 0; }