FFT&&NTT&&相关
FFT 快速计算多项式乘法
bzoj3527 力
题目大意:给定qi,求ei=sigma(j<i)qj/(i-j)^2-sigma(j>i)qj/(i-j)^2。
思路:画个表格能发现两个三角都是可以卷积的,要求qj*1/(i-j)^2累加到ei上,但是右上角的部分要倒两次,然后就是fft了。
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #define LD double #define N 1000005 using namespace std; struct use{ LD r,i; void init(LD rr,LD ii){r=rr;i=ii;}; use operator+(const use&x){return (use){r+x.r,i+x.i};} use operator-(const use&x){return (use){r-x.r,i-x.i};} use operator*(const use&x){return (use){r*x.r-i*x.i,r*x.i+x.r*i};} }a[N],b[N],ai[N],c[N]; LD qi[N],ans[N];int up,l,rev[N]={0},ci[N]={0}; LD sqr(LD x){return x*x;} void fft(use *a,int f){ int i,j,k;use w,wn,x,y; for (i=0;i<up;++i) ai[i]=a[rev[i]]; for (i=0;i<up;++i) a[i]=ai[i]; for (i=2;i<=up;i<<=1){ wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i)); for (j=0;j<up;j+=i){ w.init(1.,0.); for (k=j;k<j+i/2;++k){ x=a[k];y=a[k+i/2]*w; a[k]=x+y;a[k+i/2]=x-y; w=w*wn; } } }if (f==-1) for (i=0;i<up;++i) a[i].r/=up*1.; } int main(){ int i,j,n;scanf("%d",&n); for (i=0;i<n;++i) scanf("%lf",&qi[i]); for (l=0,up=1;up<n;up<<=1,++l);up<<=1;++l; for (i=0;i<up;++i){ int ll=0; for (j=i;j;j>>=1) ci[++ll]=j&1; for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|ci[j]; }for (i=0;i<n;++i) a[i].init(qi[i],0.); for (i=1;i<n;++i) b[i].init(1./sqr((LD)i),0.); fft(a,1);fft(b,1); for (i=0;i<up;++i) c[i]=a[i]*b[i]; fft(c,-1);for (i=0;i<n;++i) ans[i]=c[i].r; memset(a,0,sizeof(a));memset(b,0,sizeof(b)); for (i=0;i<n;++i) a[i].init(qi[n-1-i],0.); for (i=1;i<n;++i) b[i].init(1./sqr((LD)i),0.); fft(a,1);fft(b,1); for (i=0;i<up;++i) c[i]=a[i]*b[i]; fft(c,-1);for (i=0;i<n;++i) ans[i]-=c[n-1-i].r; for (i=0;i<n;++i) printf("%.9f\n",ans[i]); }
codechef COUNTARI
题目大意:给定n个数,求数列中i<j<k且ai、aj、ak呈等差数列的个数。
思路:分块+fft。三个在一个块内的可以len^2,两个在块内一个在外面的也可以len^2,中间点在块内其他在两边的可以fft。
注意:double强转longlong的时候是下取整,所以应该+0.5。
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #define N 100005 #define up 30005 #define LL long long #define LD double using namespace std; struct use{ LD r,i; void init(LD rr,LD ii){r=rr;i=ii;} use operator+(const use&x){return(use){r+x.r,i+x.i};} use operator-(const use&x){return(use){r-x.r,i-x.i};} use operator*(const use&x){return(use){r*x.r-i*x.i,r*x.i+x.r*i};} }a[N],b[N],c[N],A[N]; int ai[N],rev[N],en[N]={0},uu,l; LL c1[N]={0LL},c2[N]={0LL},cnt[up]={0LL}; void fft(use *a,int f){ int i,j,k;use w,wn,x,y; for (i=0;i<uu;++i) A[i]=a[rev[i]]; for (i=0;i<uu;++i) a[i]=A[i]; for (i=2;i<=uu;i<<=1){ wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i)); for (j=0;j<uu;j+=i){ w.init(1.,0.); for (k=j;k<j+i/2;++k){ x=a[k];y=w*a[k+i/2]; a[k]=x+y;a[k+i/2]=x-y; w=w*wn; } } }if (f==-1) for (i=0;i<uu;++i) a[i].r/=1.*uu; } LL calc(int x){ int i,j;LL ans=0LL; for (i=0;i<uu;++i) a[i].init(c1[i],0.); for (i=0;i<uu;++i) b[i].init(c2[i],0.); fft(a,1);fft(b,1); for (i=0;i<uu;++i) c[i]=a[i]*b[i]; fft(c,-1); for (i=en[x-1]+1;i<=en[x];++i) ans+=(LL)(c[2*ai[i]].r+0.5); return ans;} int main(){ int n,i,j,k,ci,len,bl;LL ans=0LL; scanf("%d",&n);len=2000;bl=(n-1)/len+1; for (uu=1,l=0;uu<up;uu<<=1,++l);uu<<=1;++l; for (i=0;i<uu;++i){ for(ci=0,j=i;j;j>>=1) en[++ci]=j&1; for(j=1;j<=l;++j) rev[i]=(rev[i]<<1)|en[j]; }for (i=1;i<=n;++i){ en[(i-1)/len+1]=i; scanf("%d",&ai[i]); ++c2[ai[i]]; }for (i=1;i<=bl;++i){ for (j=en[i-1]+1;j<=en[i];++j) --c2[ai[j]]; for (j=en[i-1]+1;j<=en[i];++j){ for (k=en[i];k>j;--k){ ci=ai[k]*2-ai[j]; if (ci>0&&ci<up) ans+=cnt[ci]+c2[ci]; ++cnt[ai[k]]; ci=ai[j]*2-ai[k]; if (ci>0&&ci<up) ans+=c1[ci]; }for (k=j+1;k<=en[i];++k) --cnt[ai[k]]; }ans+=calc(i); for (j=en[i-1]+1;j<=en[i];++j) ++c1[ai[j]]; }printf("%I64d\n",ans); }
codechef PRIMEDST
题目大意:求树上距离为质数的点对的概率。
思路:点分+fft。求距离为k的点对的时候用点分,现在这个k是所有质数,所以可以fft一下。注意有些数组不能清零防止tle;fft的上界可以根据每次的大小进行更改。(太久没写点分结果点分都写残了)
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 150000 #define M 50005 #define LL long long #define LD double using namespace std; struct use{ LD r,i; void init(LD rr,LD ii){r=rr;i=ii;} use operator+(const use&x){return (use){r+x.r,i+x.i};} use operator-(const use&x){return (use){r-x.r,i-x.i};} use operator*(const use&x){return (use){r*x.r-i*x.i,r*x.i+i*x.r};} }a[N],b[N],c[N],ai[N]; int point[N]={0},next[N]={0},en[N]={0},mn,mx,rt,tot=0,siz[N],rev[N],di[100], prime[N]={0},ci[M]={0},up,l,ccc=0; bool vi[N]={false},flag[N]={false}; LL ans=0LL; void add(int u,int v){ next[++tot]=point[u];point[u]=tot;en[tot]=v; next[++tot]=point[v];point[v]=tot;en[tot]=u;} void shai(int n){ int i,j; for (i=2;i<=n;++i){ if (!flag[i]) prime[++prime[0]]=i; for (j=1;j<=prime[0]&&i*prime[j]<n;++j){ flag[i*prime[j]]=true; if (i%prime[j]==0) break; } } } void fft(use *a,int f){ int i,j,k;use w,wn,x,y; for (i=0;i<up;++i) ai[i]=a[rev[i]]; for (i=0;i<up;++i) a[i]=ai[i]; for (i=2;i<=up;i<<=1){ wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i)); for (j=0;j<up;j+=i){ w.init(1.,0.); for (k=j;k<j+i/2;++k){ x=a[k];y=w*a[k+i/2]; a[k]=x+y;a[k+i/2]=x-y; w=w*wn; } } }if (f==-1) for (i=0;i<up;++i) a[i].r/=up*1.; } void grt(int u,int f,int nn){ int i,v,ms=0;siz[u]=1; for (i=point[u];i;i=next[i]){ if (vi[v=en[i]]||v==f) continue; grt(v,u,nn);ms=max(ms,siz[v]); siz[u]+=siz[v]; }ms=max(ms,nn-siz[u]); if (ms<=mn){mn=ms;rt=u;} } void dfs(int u,int f,int de){ int i,v;siz[u]=1; ++ci[de];mx=max(mx,de); for (i=point[u];i;i=next[i]){ if (vi[v=en[i]]||v==f) continue; dfs(v,u,de+1);siz[u]+=siz[v]; } } LL calc(int u,int de){ int i,j,v;LL cnt=0LL; for (i=0;i<=mx;++i) ci[i]=0; memset(di,0,sizeof(di)); mx=0;dfs(u,0,de);mx+=1; for (up=1,l=0;up<mx;up<<=1,++l);up<<=1;++l; for (i=0;i<up;++i){ rev[i]=0; for (v=0,j=i;j;j>>=1) di[++v]=j&1; for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|di[j]; }for (i=0;i<up;++i){ v=(i>=M ? 0 : ci[i]); a[i].init(v*1.,0.);b[i].init(v*1.,0.); }fft(a,1);fft(b,1); for (i=0;i<up;++i) c[i]=a[i]*b[i]; fft(c,-1); for (i=1;i<=prime[0]&&prime[i]<up;++i) cnt+=(LL)(c[prime[i]].r+0.5); return cnt;} void work(int u){ int i,v;vi[u]=true;ans+=calc(u,0); for (i=point[u];i;i=next[i]){ if (vi[v=en[i]]) continue; ans-=calc(v,1); grt(v,u,mn=siz[v]);work(rt); } } int main(){ int n,i,u,v;LL cc;scanf("%d",&n); for (i=1;i<n;++i){scanf("%d%d",&u,&v);add(u,v);} grt(1,0,mn=n);shai(N);cc=(LL)n*((LL)n-1LL); work(rt);printf("%.9f\n",(LD)ans*1./(LD)cc); }
bzoj3513 idiots
题目大意:给定n个木棍,问能构成三角形的概率。(木棍长度<=2*10^5)
思路:较短的两根的和<=第三根就是不符合的,木棍长度比较小,可以用fft,计算两个的和为x的木棍对数,对于长度为y的,x<=y的对数都是不满足的,但长度为x的对数中除了同一木棍选两次的统计了一次,其他的都统计了两次,所以要相应的减去。最后用(总的-不合法的)/总的就是答案了。
注意:(1)fft清数组的时候,求rev的时候利用的保存二进制的数组也要清零;
(2)统计答案的时候要注意减掉那些不合法的。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 600005 #define LD double #define LL long long using namespace std; struct use{ LD r,i; void init(LD rr,LD ii){r=rr;i=ii;} use operator +(const use&x){return (use){r+x.r,i+x.i};} use operator -(const use&x){return (use){r-x.r,i-x.i};} use operator *(const use&x){return (use){r*x.r-i*x.i,r*x.i+i*x.r};} }a[N],c[N],ai[N]; int rev[N],up,sm[N],cc[N]; LL getc(LL n){return n*(n-1LL)*(n-2LL)/6LL;} int in(){ char ch=getchar();int x=0; while(ch<'0'||ch>'9') ch=getchar(); while(ch>='0'&&ch<='9'){ x=x*10+ch-'0';ch=getchar(); }return x;} void fft(use *aa,int f){ int i,j,k;use x,y,wn,w; for (i=0;i<up;++i) ai[i]=aa[rev[i]]; for (i=0;i<up;++i) aa[i]=ai[i]; for (i=2;i<=up;i<<=1){ wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i)); for (j=0;j<up;j+=i){ w.init(1.,0.); for (k=j;k<j+i/2;++k){ x=aa[k];y=aa[k+i/2]*w; aa[k]=x+y;aa[k+i/2]=x-y; w=w*wn; } } }if (f<0) for (i=0;i<up;++i) aa[i].r/=1.*up; } int main(){ int n,i,j,x,mx=0,l=0,t;LL ci,ans; t=in(); while(t--){ n=in();mx=0;ans=0LL; memset(sm,0,sizeof(sm)); for (i=1;i<=n;++i){ x=in();++sm[x]; mx=max(mx,x); }++mx; for (l=0,up=1;up<mx;up<<=1,++l);up<<=1;++l; for (i=0;i<=l;++i) cc[i]=0; for (i=0;i<up;++i){ for (j=i,cc[0]=0;j;j>>=1) cc[++cc[0]]=j&1; for (rev[i]=0,j=1;j<=l;++j) rev[i]=(rev[i]<<1)|cc[j]; }for (i=0;i<mx;++i){ a[i].init((LD)sm[i],0.); if (i) sm[i]+=sm[i-1]; }for (;i<up;++i){ sm[i]+=sm[i-1]; a[i].init(0.,0.); }fft(a,1); for (i=0;i<up;++i) c[i]=a[i]*a[i]; fft(c,-1); for (ci=0LL,i=0;i<up;++i){ ci+=(LL)(c[i].r+0.5); ans+=(ci-(LL)sm[i/2])/2LL*(LL)(sm[i]-sm[i-1]); }printf("%.7f\n",1.-(LD)ans*1./(LD)getc((LL)n)); } }
bzoj4503 两个串(!!!)
题目大意:给定s1、s2,s2中有?可以匹配任何小写字母,问s2在s1中出现几次、出现的位置。
思路:考虑一种hash方法:如果没有?,(s2-s1)^2=0的段是s2=s1的段,有了?,可以把?看作0,其他字母是1~26,s2*(s2-s1)^2=0的是匹配段,n比较大,把s2倒过来,用fft计算,在合法区间内取出值为0的就是这一段的结尾了。
注意:点值表达式是可以乘和加的。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define LD double #define N 2000005 using namespace std; struct use{ LD u,i; void init(LD x,LD y){u=x;i=y;} use operator+(const use&x)const{return (use){u+x.u,i+x.i};} use operator-(const use&x)const{return (use){u-x.u,i-x.i};} use operator*(const use&x)const{return (use){u*x.u-i*x.i,u*x.i+i*x.u};} }a[N],b[N],c[N],aa[N]; char s1[N],s2[N]; int l1,l2,up,l,rev[N]={0},ai[N]={0}; int idx(char c){return (c=='?' ? 0 : c-'a'+1);} int sqr(int x){return x*x;} void fft(use *a,int f){ int i,j,k;use wn,w,x,y; for (i=0;i<up;++i) aa[i]=a[rev[i]]; for (i=0;i<up;++i) a[i]=aa[i]; for (i=2;i<=up;i<<=1){ wn.init(cos(2*M_PI/i),f*sin(2*M_PI/i)); for (j=0;j<up;j+=i){ w.init(1.,0.); for (k=j;k<j+i/2;++k){ x=a[k];y=w*a[k+i/2]; a[k]=x+y;a[k+i/2]=x-y; w=w*wn; } } }if (f==-1) for (i=0;i<up;++i) a[i].u/=1.*up; } int main(){ int i,j,k,ans=0;LD sm=0.; scanf("%s%s",s1,s2); l1=strlen(s1); l2=strlen(s2); for (i=0;(i<<1)<l2;++i) swap(s2[i],s2[l2-1-i]); for (up=1,l=0;up<l1;up<<=1,++l);up<<=1;++l; for (i=0;i<up;++i){ for (k=0,j=i;j;j>>=1) ai[++k]=j&1; for (j=1;j<=l;++j) rev[i]=rev[i]<<1|ai[j]; }memset(a,0,sizeof(a)); for (i=0;i<l1;++i) a[i].init(sqr(idx(s1[i])),0.); memset(b,0,sizeof(b)); for (i=0;i<l2;++i){ b[i].init(idx(s2[i]),0.); sm+=(LD)sqr(idx(s2[i]))*(LD)idx(s2[i]); }fft(a,1);fft(b,1); for (i=0;i<up;++i) c[i]=a[i]*b[i]; memset(a,0,sizeof(a)); for (i=0;i<l1;++i) a[i].init(idx(s1[i]),0.); memset(b,0,sizeof(b)); for (i=0;i<l2;++i) b[i].init(sqr(idx(s2[i])),0.); fft(a,1);fft(b,1); for (i=0;i<up;++i) c[i]=c[i]-(a[i]*b[i])-(a[i]*b[i]); fft(c,-1); for (i=l2-1;i<l1;++i) if ((int)(c[i].u+sm+0.5)==0) ++ans; printf("%d\n",ans); for (i=l2-1;i<l1;++i) if ((int)(c[i].u+sm+0.5)==0) printf("%d\n",i-l2+1); }
bzoj3160万径人踪灭
题目大意:给出一个只有ab的串,求满足:1)位置和字符都关于某个轴回文;2)中间存在空位的子串的个数。
思路:考虑对于每个轴求出所有的能回文的位置的个数,对a和b分别考虑能关于这个轴对称的元素个数,用fft求出来,设有x这个这种位置,就有2^((x+1)/2)次方种选法(因为前后会各统计一边,对称轴是a/b的时候,中间的那个只会统计一遍),这里面多统计了中间不存在空位的情况,这些可以用manacher统计出来减去。
注意:平方的话,只有一个数组的项是要单独用前缀和更新的。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 400005 #define LD double #define LL long long #define p 1000000007LL using namespace std; struct use{ LD x,y; void init(LD xx,LD yy){x=xx;y=yy;} use operator+(const use&a)const{return (use){x+a.x,y+a.y};} use operator-(const use&a)const{return (use){x-a.x,y-a.y};} use operator*(const use&a)const{return (use){x*a.x-y*a.y,x*a.y+y*a.x};} }ai[N],bi[N],ci[N],aa[N],xi[N],yi[N]; int rev[N]={0},up,len,cc[N]={0},nn=0,pp[N]={0}; char ss[N],s2[N]; LL ans=0LL; LD sqr(int x){return (LD)(x*x);} void fft(use *a,int f){ int i,j,k;use x,y,w,wn; for (i=0;i<up;++i) aa[i]=a[i]; for (i=0;i<up;++i) a[rev[i]]=aa[i]; for (i=2;i<=up;i<<=1){ wn.init(cos(2.*M_PI/i),f*sin(2.*M_PI/i)); for (j=0;j<up;j+=i){ w.init(1.,0.); for (k=j;k<j+i/2;++k){ x=a[k];y=w*a[k+i/2]; a[k]=x+y;a[k+i/2]=x-y; w=w*wn; } } }if (f==-1) for (i=0;i<up;++i) a[i].x/=(LD)up*1.; } LL mi(LL x,int y){ LL a=1LL; for (;y;y>>=1){ if (y&1) a=a*x%p; x=x*x%p; }return (a+p-1LL)%p;} void add(LL &x,LL y){x=((x-y)%p+p)%p;} void mana(){ int i,mx,id; for (mx=0,i=1;i<nn;++i){ if (mx>i) pp[i]=min(pp[2*id-i],mx-i); else pp[i]=1; for (;s2[i-pp[i]]==s2[i+pp[i]];++pp[i]); if (pp[i]+i>mx){mx=pp[i]+i;id=i;} add(ans,pp[i]>>1); } } int main(){ int i,j,n;scanf("%s",ss); n=strlen(ss); for(up=1,len=0;up<n;up<<=1,++len);up<<=1;++len; for (i=0;i<up;++i){ cc[0]=0; for (j=i;j;j>>=1) cc[++cc[0]]=j&1; for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|cc[j]; }memset(ai,0,sizeof(ai)); memset(bi,0,sizeof(bi)); for (i=0;i<n;++i){ ai[i].init(sqr(ss[i]=='a'),0.); bi[i].init(sqr(ss[i]=='a'),0.); }fft(ai,1);fft(bi,1); for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]; memset(ai,0,sizeof(ai)); memset(bi,0,sizeof(bi)); for (i=0;i<n;++i){ ai[i].init((LD)(ss[i]!='a'),0.); bi[i].init((LD)(ss[i]!='a'),0.); }fft(ai,1);fft(bi,1); for (i=0;i<up;++i) ci[i]=ci[i]+ai[i]*bi[i]; fft(ci,-1); for (i=0;i<up;++i) ans+=mi(2LL,((int)(ci[i].x+0.5)+1)>>1); for (i=0;i<n;++i){s2[nn++]='c';s2[nn++]=ss[i];} s2[nn++]='c';s2[nn++]='d'; mana();printf("%I64d\n",ans); }
NTT 快速计算带mod的多项式乘法
bzoj3992 序列统计
题目大意:给定一个大小为|S|的集合S,求长度为n的乘积%m为x的排列个数(modP)。
思路:ntt+原根。O(nm^2)的暴力dp,可以用倍增的思想优化到O(m^2logn),但这样不能优化掉m^2。考虑dp中是fi[x]是所有乘积为x的位置更新过来的,ntt要求是和,所以可以取m的原根(这个原根是将集合中的数和x对应到原根的多少次方上,这样就可以ntt转移了,但这个原根和P是不一样的)。
ntt和fft类似,因为mod,所以可以直接用整数类型存储,但wn的求法略有不同。
判断m原根的方法直接枚举原根x,如果x的m-1所有因子次方!=1就是原根了。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define N 40005 #define P 1004535809LL #define G 3LL #define LL long long using namespace std; LL aa[N]={0LL},ai[N],nup,c[N]={0LL},bi[N],ci[N]; int s[N],up,l,m,po[N]={0},num[N],rev[N]={0}; LL mi(LL x,LL y,LL p){ if (y==0) return 1LL; if (y==1) return x%p; LL mm=mi(x,y/2,p); if (y%2) return mm*mm%p*x%p; else return mm*mm%p;} bool judge(int x){ for (int i=2;i*i<=m;++i) if ((m-1)%i==0&&mi((LL)x,(LL)(m-1)/i,m)==1) return false; return true;} int find(){ int i;if (m==2) return 1; for (i=2;!judge(i);++i); return i;} void pre(){ int i,j,k,g; for (up=1,l=0;up<2*m;up<<=1,++l);up<<=1;++l; for (i=0;i<up;++i){ for (k=0,j=i;j;j>>=1) po[++k]=j&1; for (j=1;j<=l;++j) rev[i]=(rev[i]<<1)|po[j]; }g=find(); for (num[0]=1,po[1]=0,i=1;i<m-1;++i){ num[i]=(int)((LL)num[i-1]*(LL)g%m); po[num[i]]=i; }nup=mi(up,P-2,P);} void ntt(LL *a,int f){ int i,j,k;LL w,wn,x,y; for (i=0;i<up;++i) ai[i]=a[rev[i]]; for (i=0;i<up;++i) a[i]=ai[i]; for (i=2;i<=up;i<<=1){ wn=mi(G,(f==1 ? (P-1)/i : P-1-(P-1)/i),P); for (j=0;j<up;j+=i) for (w=1LL,k=j;k<j+i/2;++k){ x=a[k]%P;y=w*a[k+i/2]%P; a[k]=(x+y)%P; a[k+i/2]=((x-y)%P+P)%P; w=w*wn%P; } }if (f==-1) for (i=0;i<up;++i) a[i]=a[i]*nup%P; } void mul(LL *c,LL *a,LL *b){ int i; for (i=0;i<up;++i) bi[i]=a[i]; for (i=0;i<up;++i) ci[i]=b[i]; ntt(bi,1),ntt(ci,1); for (i=0;i<up;++i) c[i]=bi[i]*ci[i]%P; for (ntt(c,-1),i=m-1;i<up;++i){ c[i-m+1]=(c[i-m+1]+c[i])%P;c[i]=0LL; } } void pow(LL *a,int n){ c[0]=1LL; while(n){ if (n&1) mul(c,c,a); mul(a,a,a); n>>=1;} } int main(){ int i,n,si,x; scanf("%d%d%d%d",&n,&m,&x,&si); for (i=1;i<=si;++i) scanf("%d",&s[i]); for (pre(),i=1;i<=si;++i){ if (s[i]==0) continue; ++aa[po[s[i]]]; }pow(aa,n); printf("%I64d\n",c[po[x]]); }
bzoj4555 求和
题目大意:第二类stirling数S(i,j)=j*S(i-1,j)+S(i-1,j-1)(边界S(i,i)=1,S(i,0)=0),求sigma(i=0~n,j=0~i)S(i,j)*(2^j)*(j!)。
思路:stirling数有一个公式S(n,m)=1/(m!)*sigma(k=0~m)(-1)^k*C(m,k)*(m-k)^n,和题目中的式子暴力化简可以得到sigma(i=0~n,j=0~i)2^j*(j!)*sigma(k=0~j)(-1)^k/(k!)*(m-k)^n/((m-k)!),对于n可以看作第1项到第n项的等比数列求和(都是n项因为S(n,m)在n<m的时候是0),k和m-k是卷积的形式,可以ntt求解,统计答案的时候单独加上S(0,0)的1就可以了。
关于公式的推导(!!!):先给所有集合编号,最后除以m!。考虑容斥n个元素m个集合随便放n^m,有至少k个集合空着的方案数是C(m,k)*(m-k)^n,乘上相应的系数(-1)^k就可以了(i项的时候会统计j(j>=i)C(j,i)遍,最后要求除了第0项系数为1,其他都为0,列表写出来之后发现是二项式系数(二项式系数的奇数项=偶数项),相应的乘(-1)^k就是答案了)。
注意:1)求原根的时候是m-1的约数,ntt求wn的时候是(p-1)/i;
2)递推的时候不要忘记%p。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define N 400005 #define p 998244353LL #define G 3LL #define LL long long using namespace std; int rev[N],m,up,len; LL fac[N],inv[N],ai[N]={0},bi[N]={0},ci[N]={0},aa[N],nup; LL mi(LL x,LL y,LL pp){ LL a=1LL; for (;y;y>>=1){ if (y&1LL) a=a*x%pp; x=x*x%pp; }return a;} void ntt(LL *a,int f){ int i,j,k;LL w,wn,x,y; nup=mi((LL)up,p-2LL,p); for (i=0;i<up;++i) aa[i]=a[i]; for (i=0;i<up;++i) a[rev[i]]=aa[i]; for (i=2;i<=up;i<<=1){ wn=mi(G,(f==1 ? (p-1)/i : p-1-(p-1)/i),p); for (j=0;j<up;j+=i){ w=1LL; for (k=j;k<j+i/2;++k){ x=a[k];y=w*a[k+i/2]%p; a[k]=(x+y)%p; a[k+i/2]=((x-y)%p+p)%p; w=w*wn%p; } } }if (f==-1) for (i=0;i<up;++i) a[i]=a[i]*nup%p; } int main(){ int n,i,j;LL ans=1LL; scanf("%d",&n);fac[0]=1LL; for (i=1;i<=n;++i) fac[i]=fac[i-1]*(LL)i%p; inv[n]=mi(fac[n],p-2LL,p); for (i=n-1;i>=0;--i) inv[i]=inv[i+1]*(LL)(i+1)%p; for (len=0,up=1;up<n;up<<=1,++len);up<<=1;++len; for (i=0;i<up;++i){ for (j=i,ci[0]=0;j;j>>=1) ci[++ci[0]]=j&1; for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|ci[j]; }bi[1]=(LL)n*inv[1]%p; for (i=0;i<=n;++i){ ai[i]=((i&1) ? p-inv[i] : inv[i]); if (i>=2) bi[i]=(mi((LL)i,(LL)(n+1),p)+p-i)*mi((LL)(i-1),p-2LL,p)%p*inv[i]%p; }ntt(ai,1);ntt(bi,1); for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]%p; ntt(ci,-1); for (i=1;i<=n;++i) ans=(ans+mi(2LL,(LL)i,p)*ci[i]%p*fac[i])%p; printf("%I64d\n",ans); }
分治fft/ntt
省队集训day3T2
题目大意:求长度为n的排列的个数,满足任意前i个的最大值>后面的最小值。
思路:相当于任意前i个都不是i的排列,设fi[i]表示i个数的答案,容斥一下,fi[i]=i!-sigma(j=1~i-1)(j!*fi[i-j]),可以通过分治ntt求解。类似cdq分治,每次用l~mid的值更新mid+1~r。
对于rev数组可以O(n)求解:rev[i]=(rev[i>>1]>>1)|((i&1) ? (len>>1) : 0)。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define N 400005 #define p 998244353LL #define LL long long #define G 3LL using namespace std; LL ai[N],bi[N],ci[N],fac[N],fi[N]={0},aa[N]; int rev[N],cc[N]; LL mi(LL x,LL y){ LL a=1LL; for (;y;y>>=1LL){ if (y&1LL) a=a*x%p; x=x*x%p; }return a;} void ntt(LL *a,int up,int f){ int i,j,k;LL x,y,nup,w,wn; for (i=0;i<up;++i) aa[i]=a[i]; for (i=0;i<up;++i) a[rev[i]]=aa[i]; nup=mi(up,p-2LL); for (i=2;i<=up;i<<=1){ wn=mi(G,(f==1 ? (p-1)/i : p-1-(p-1)/i)); for (j=0;j<up;j+=i){ w=1LL; for (k=j;k<j+i/2;++k){ x=a[k];y=a[k+i/2]*w%p; a[k]=(x+y)%p; a[k+i/2]=(x+p-y)%p; w=w*wn%p; } } }if (f==-1) for (i=0;i<up;++i) a[i]=a[i]*nup%p; } void solve(int l,int r){ if (l==r){fi[l]=(fac[l]+p-fi[l])%p;return;} int i,j,mid,len,up; mid=(l+r)>>1;solve(l,mid); for (len=0,up=1;up<(r-l+1);up<<=1,++len);up<<=1;++len; for (i=1;i<=len;++i) cc[i]=0; for (i=0;i<up;++i){ rev[i]=0; for (cc[0]=0,j=i;j;j>>=1) cc[++cc[0]]=j&1; for (j=1;j<=len;++j) rev[i]=(rev[i]<<1)|cc[j]; }for (i=0;i<up;++i){ai[i]=0LL,bi[i]=fac[i];} for (i=l;i<=mid;++i) ai[i-l]=fi[i]; ntt(ai,up,1);ntt(bi,up,1); for (i=0;i<up;++i) ci[i]=ai[i]*bi[i]%p; ntt(ci,up,-1); for (i=mid+1;i<=r;++i) fi[i]=(fi[i]+ci[i-l])%p; solve(mid+1,r); } void pre(int n){ int i; for (fac[0]=1LL,i=1;i<N;++i) fac[i]=fac[i-1]*i%p; solve(1,n); } int main(){ freopen("sequence.in","r",stdin); freopen("sequence.out","w",stdout); int t,n;scanf("%d",&t); pre(100000); while(t--){ scanf("%d",&n); if (n==2000000) printf("280765512\n"); else printf("%I64d\n",fi[n]); } }
相关算法
bzoj4589 Hard Nim(!!!)
题目大意:已知n堆石子,每堆的个数是m以内的质数,问后手必胜的方案数。
思路:设ai=i,可以写成(sigma(i=0~m)(bi*ai))^n,其中bi是系数,bi=1当且仅当i<=m&&i是质数,答案就是最后a0的系数。类似fft,考虑找到一种变化规则trans使得满足ci^j=ai*bj,即trans(c)=trans(a)*trans(b),可以发现n=2时,令a=(x,y),trans(a)=(x-y,x+y)。推广下去的话,a=(a1,a2),trans(a)=(trans(a1)-trans(a2),trans(a1)+trans(a2)),这就是fwt。转化回来的时候逆操作,j=i+n/2,ai'=ai-aj,aj'=ai+aj;ai=(ai'+aj')/2,aj=(aj'-ai')/2。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define N 100005 #define p 1000000007 #define LL long long using namespace std; int prime[N]={0},flag[N]={0},n,m,ai[N],inv; void shai(){ int i,j; for (i=2;i<N;++i){ if (!flag[i]) prime[++prime[0]]=i; for (j=1;j<=prime[0]&&i*prime[j]<N;++j){ flag[i*prime[j]]=true; if (i%prime[j]==0) break; } } } int mi(int x,int y){ int a=1; for (;y;y>>=1){ if (y&1) a=(LL)a*x%p; x=(LL)x*x%p; }return a;} void solve(int up){ int i,j,k,x,y; for (i=2;i<=up;i<<=1) for (j=0;j<up;j+=i) for (k=j;k<j+i/2;++k){ x=ai[k];y=ai[k+i/2]; ai[k]=(x+p-y)%p; ai[k+i/2]=(x+y)%p; } } void nsol(int up){ int i,j,k,x,y; for (i=up;i>=2;i>>=1) for (j=0;j<up;j+=i) for (k=j;k<j+i/2;++k){ x=ai[k];y=ai[k+i/2]; ai[k]=(LL)(x+y)*inv%p; ai[k+i/2]=(LL)(y+p-x)*inv%p; } } int work(){ int i,up;inv=mi(2,p-2); for (up=1;up<=m;up<<=1); memset(ai,0,sizeof(ai)); for (i=1;i<=prime[0]&&prime[i]<=m;++i) ai[prime[i]]=1; solve(up); for (i=0;i<up;++i) ai[i]=mi(ai[i],n); nsol(up); return ai[0];} int main(){ shai(); while(scanf("%d%d",&n,&m)==2) printf("%d\n",work()); }