51nod 1387 移数字
回来拉模版的时候意外发现这个题还没写题解,所以就随便补点吧。
题意其实就是要你求n的阶乘在模意义下的值。
首先找出来一个最大的$m$满足$m^2<=n$,对于大于$m^2$部分的数我们直接暴力求就行了,问题是求$m^2$以内的答案。
先构造一个多项式$f(x)=(x+1)(x+2)(x+3)……(x+m)$,然后求它在$x=0、x=m……x=m(m-1)$位置的值,然后求个值全部乘起来就行了。
稍微说下怎么做多点求值,构造两个多项式
$$G_1(x)=(x-x_1)(x-x_2)……(x-x_{\left\lfloor\frac{m}{2}\right\rfloor})=x(x-m)……(x-\left\lfloor\frac{m}{2}\right\rfloor m)$$
$$G_2(x)=(x-x_{\left\lfloor\frac{m}{2}\right\rfloor+1})……(x-x_m)=(x-\left\lfloor\frac{m}{2}\right\rfloor m-m)……(x-m^2+m)$$
然后拿$f(x)$对$G_1(x)$取模得到一个$\left\lfloor\frac{m}{2}\right\rfloor$次的多项式,这个多项式在$x_1、x_2……x_{\left\lfloor\frac{m}{2}\right\rfloor}$位置的值跟$f(x)$是一样的(这是因为构造出来的式子在这些位置都等于0,而我们可以把多项式除法看成很多次减法,所以这个值不会变),后半部分同理用$G_2(x)$处理,这个时候问题规模就减半了,由此递归即可。
题目最后复杂度是$O(\sqrt{n}log^2\sqrt{n})$
#include<cstdio> #include<cstring> #include<algorithm> #define lp (p<<1) #define rp ((p<<1)|1) #define ll long long #define MN 200200 using namespace std; int read_p,read_ca; inline int read(){ read_p=0;read_ca=getchar(); while(read_ca<'0'||read_ca>'9') read_ca=getchar(); while(read_ca>='0'&&read_ca<='9') read_p=read_p*10+read_ca-48,read_ca=getchar(); return read_p; } int _n,n,m,t,e[MN],_e[MN],Mmh=0,D[MN],C_a[MN],C_b[MN],C_c[MN],N_c[MN],D_a[MN],D_b[MN],D_c[MN],tot,gg=2,MMH[MN],L[MN*5]; int rt[MN*40],B[MN*40],sz=0; int MOD=104857601; inline void M(int &x){while(x>=MOD)x-=MOD;} inline int mi(int a,int b){ int mmh=1; while (b){ if (b&1) mmh=1LL*mmh*a%MOD; b>>=1;a=1LL*a*a%MOD; } return mmh; } inline void inv(){ int base=mi(gg,(MOD-1)/tot),_base=mi(base,MOD-2); e[0]=_e[0]=1; for (register int i=1;i<=tot;i++) e[i]=1LL*e[i-1]*base%MOD,_e[i]=1LL*_e[i-1]*_base%MOD; } inline void NTT(int N,int a[],int w[]){ register int i,j,k,m,z; for (i=j=0;i<N;i++){ if (i>j) swap(a[i],a[j]); for (k=N>>1;(j^=k)<k;k>>=1); } for (i=2;i<=N;i<<=1) for (m=i>>1,j=0;j<N;j+=i) for (k=0;k<m;k++){ z=1LL*a[j+k+m]*w[tot/i*k]%MOD; a[j+k+m]=a[j+k]>z?a[j+k]-z:MOD-z+a[j+k]; a[j+k]=a[j+k]-MOD+z;if (a[j+k]<0) a[j+k]+=MOD; } } inline void cc(int N,int a[],int b[],int c[]){ memcpy(C_a,a,N<<2);memcpy(C_b,b,N<<2); NTT(N,C_a,e);NTT(N,C_b,e); for (register int i=0;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD; NTT(N,c,_e); int w=mi(N,MOD-2); for (register int i=0;i<N;i++) c[i]=1LL*c[i]*w%MOD; } inline void cc(int n,int m,int a[],int b[],int c[]){ int N; for (N=1;N<(n+m);N<<=1); memcpy(C_a,a,n<<2);memcpy(C_b,b,m<<2); fill(C_a+n,C_a+N,0);fill(C_b+m,C_b+N,0); NTT(N,C_a,e);NTT(N,C_b,e); for (register int i=0;i<N;i++) c[i]=1LL*C_a[i]*C_b[i]%MOD; NTT(N,c,_e); int w=mi(N,MOD-2); for (register int i=0;i<N;i++) c[i]=1LL*c[i]*w%MOD; } inline void ny(int p,int a[],int b[]){ if (p==1) b[0]=mi(a[0],MOD-2);else{ ny((p+1)>>1,a,b); int N=1; while (N<(p<<1))N<<=1; copy(a,a+p,N_c);fill(N_c+p,N_c+N,0); NTT(N,N_c,e);NTT(N,b,e); for (register int i=0;i<N;i++) b[i]=(2LL-1LL*N_c[i]*b[i]%MOD+MOD)*b[i]%MOD; NTT(N,b,_e); int w=mi(N,MOD-2); for (register int i=0;i<N;i++) b[i]=1LL*b[i]*w%MOD; fill(b+p,b+N,0); } } inline void re_copy(int n,int a[],int b[]){for (register int i=0;i<n;i++) b[i]=a[n-i-1];} inline void div(int n,int m,int a[],int b[],int d[],int r[]){ int N=1,t=n-m+1,i; while (N<t<<1)N<<=1; memset(D_a,0,N<<2); memset(D_b,0,N<<2); memset(D_c,0,N<<2); memset(d,0,N<<2); re_copy(m,b,D_b); re_copy(n,a,D_a); ny(t,D_b,D_c); for (N=1;N<(n<<1);N<<=1); cc(n,t,D_a,D_c,D_b); re_copy(t,D_b,d); fill(d+t,d+N,0); cc(t,m,d,b,D_a); for (i=0;i<m;i++) r[i]=(1LL*a[i]-D_a[i]+MOD)%MOD; fill(r+m,r+N,0); } inline bool ju(int x){ int u=MOD-1; for (register int i=2;i*i<=u;i++) if (u%i==0) if (mi(x,u/i)==1) return 1; return 0; } int mmh=1; inline void Mmhp(int p,int l,int r){ if (l==r){ L[p]=sz; rt[sz]=l; rt[sz+1]=1; sz+=2; return; } int mid=l+r>>1; Mmhp(lp,l,mid);Mmhp(rp,mid+1,r); cc(mid-l+2,r-mid+1,rt+L[lp],rt+L[rp],rt+sz); L[p]=sz; sz+=r-l+2; } inline void Mmhrt(int p,int l,int r){ if (l==r){ L[p]=sz; rt[sz]=(MOD-1LL*m*l%MOD)%MOD; rt[sz+1]=1; sz+=2; return; } int mid=l+r>>1; Mmhrt(lp,l,mid);Mmhrt(rp,mid+1,r); cc(mid-l+2,r-mid+1,rt+L[lp],rt+L[rp],rt+sz); L[p]=sz; sz+=r-l+2; } inline void _Mmh(int p,int l,int r,int fi,int LL){ div(LL,r-l+2,B+fi,rt+L[p],D,B+sz); int mid=l+r>>1,s=sz; sz+=r-l+2; if (l==r) mmh=1LL*B[s]*mmh%MOD;else _Mmh(lp,l,mid,s,r-l+1),_Mmh(rp,mid+1,r,s,r-l+1); } int main(){ register int i; n=read(); MOD=read(); if (n>=MOD) return printf("0\n"),0; while(ju(gg))gg++; for (m=1;(m+1)*(m+1)<=n;m++); for (tot=1;tot<((m+2)<<1);tot<<=1);inv(); for (i=m*m+1;i<=n;i++) mmh=1LL*mmh*i%MOD; sz=0;Mmhp(1,1,m); for (i=L[1];i<=L[1]+m;i++) B[i-L[1]]=rt[i]; sz=0;Mmhrt(1,0,m-1); sz=m+1;_Mmh(1,0,m-1,0,m+1); if (n&1) mmh=1LL*mmh*mi(2,MOD-2)%MOD; printf("%d\n",mmh); }