NTT
重新改了一下板子 抄起来比较短
#include <bits/stdc++.h> using namespace std; #define rep(i,h,t) for (int i=h;i<=t;i++) #define dep(i,t,h) for (int i=t;i>=h;i--) #define ll long long #define IL inline const int mo=998244353; const int G=3; const int N=4e5; const int mxl=(1<<18); struct NTT{ int n,m,wn[N],a[N],b[N],C[N],D[N],inv[N]; int fsp(int x,int y) {int ans;for (ans=1;y;y>>=1,x=1ll*x*x%mo) if (y&1) ans=1ll*ans*x%mo; return ans;} void pre(){ inv[0]=inv[1]=1; rep(i,2,N-1) inv[i]=(1ll*inv[mo%i]*(mo-(mo/i)))%mo; int x=fsp(G,(mo-1)/mxl); wn[mxl>>1]=1; rep(i,(mxl>>1)+1,mxl-1) wn[i]=1ll*wn[i-1]*x%mo; dep(i,(mxl>>1)-1,1) wn[i]=wn[i<<1]; } inline int add(int x,int y) {return x+y>=mo?x+y-mo:x+y;} void clear() { rep(i,0,n) a[i]=b[i]=0; } void ntt(int *a,int f) { if (f>0){ for (int k=n>>1;k;k>>=1) for (int i=0;i<n;i+=k<<1) for (int j=0;j<k;j++){ int x=a[i+j],y=a[i+j+k]; a[i+j+k]=1ll*(x-y+mo)*wn[k+j]%mo; a[i+j]=add(x,y); } } else{ for (int k=1;k<n;k<<=1) for (int i=0;i<n;i+=(k<<1)) for (int j=0;j<k;j++){ int x=a[i+j],y=1ll*a[i+j+k]*wn[k+j]%mo; a[i+j+k]=add(x,mo-y); a[i+j]=add(x,y); } for (int i=0,inv=mo-(mo-1)/n;i<n;i++) a[i]=1ll*a[i]*inv%mo; reverse(a+1,a+n); } } IL void js(int *A,int *B,int len) { for (n=1;n<len;n*=2); for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=a[i]; clear(); } IL void js(int *A,int len) { for (n=1;n<len;n*=2); ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) A[i]=a[i]; clear(); } IL void getcj(int *A,int *B,int len) { for (n=1;n<2*len;n*=2); for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=a[i]; clear(); } IL void getinv(int *A,int *B,int len) { if (len==1) {B[0]=fsp(A[0],mo-2); return;} getinv(A,B,(len+1)>>1); for (n=1;n<=len*2;n*=2); for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=1ll*a[i]*b[i]%mo*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=((2*B[i]-a[i])%mo+mo)%mo; clear(); } IL void getsqrt(int *A,int *B,int len) { int inv2=fsp(2,mo-2); if (len==1) { B[0]=sqrt(A[0]); return;} getsqrt(A,B,(len+1)>>1); int C[N]; getinv(B,C,len); getcj(A,C,len); for (int i=0;i<n;i++) B[i]=1ll*(C[i]+B[i])%mo*inv2%mo; } IL void getDao(int *a,int *b,int len) { for (int i=1;i<len;i++) b[i-1]=1ll*i*a[i]%mo; b[len-1]=0; } IL void getjf(int *a,int *b,int len) { for (int i=0;i<=len;i++) b[i+1]=1ll*a[i]*inv[i+1]%mo; b[0]=0; } IL void getln(int *A,int *B,int len) { getDao(A,C,len); getinv(A,D,len); getcj(C,D,len); getjf(D,B,len); rep(i,0,len+1) C[i]=0,D[i]=0; } IL void getexp(int *A,int *B,int len) { if (len==1) {B[0]=1; return;} getexp(A,B,(len+1)>>1); int C[N]; getln(B,C,len); for(int i=0;i<len;i++) C[i]=((-C[i]+A[i])%mo+mo)%mo; C[0]=(C[0]+1)%mo; getcj(C,B,len); } }F; int now[N],f[N],g[N]; void solve(int h,int t) { if (h>=t) return; int mid=(h+t)/2; solve(h,mid); rep(i,h,mid) F.a[i-h]=f[i]; rep(i,1,t-h) F.b[i]=g[i]; F.js(now,(mid-h+1)+(t-h+1)); rep(i,mid+1,t) f[i]=(f[i]+now[i-h])%mo; solve(mid+1,t); } int tmpa[N],tmpb[N]; vector<int> operator *(const vector<int> &t1,const vector<int> &t2) { int n1=t1.size(),n2=t2.size(); rep(i,0,n1-1) tmpa[i]=t1[i]; rep(i,0,n2-1) tmpb[i]=t2[i]; F.js(tmpa,tmpb,n1+n2-1); vector<int> c; rep(i,0,n1+n2-2) c.push_back(tmpb[i]); rep(i,0,n1+n2) tmpa[i]=tmpb[i]=0; return c; } void operator += (vector<int> &a,const vector<int> &b){ if (b.size()>a.size()) a.resize(b.size()); for (int i=0,si=b.size();i<si;i++) a[i]=(a[i]+b[i])%mo; } struct re{ vector<int> a[2][2]; }M[N]; #define mid ((h+t)/2) re solve1(int h,int t) { if (h==t) return M[h]; re z; re x=solve1(h,mid),y=solve1(mid+1,t); rep(i,0,1) rep(j,0,1) rep(k,0,1) { z.a[i][k]+=x.a[i][j]*y.a[j][k]; } return z; } int a[N],b[N]; vector<int> ve[N]; #define vii vector<int> vii solve2(int h,int t) { if (h==t) return ve[h]; return solve2(h,mid)*solve2(mid+1,t); } int main() { F.pre(); return 0; } // /* 3 2 1 2 2 1 2 2 1 2 */
大模数NTT
#include <bits/stdc++.h> using namespace std; #define rep(i,h,t) for (int i=h;i<=t;i++) #define dep(i,t,h) for (int i=t;i>=h;i--) #define ll long long #define IL inline //const int mo=998244353; const ll mo=31525197391593473; const int G=3; const int N=4e5; const int mxl=(1<<18); struct NTT{ ll n,m,a[N],b[N]; __int128 wn[N]; ll fsp(__int128 x,ll y) {__int128 ans;for (ans=1;y;y>>=1,x=x*x%mo) if (y&1) ans=ans*x%mo; return ans;} void pre(){ ll x=fsp(G,(mo-1)/mxl); wn[mxl>>1]=1; rep(i,(mxl>>1)+1,mxl-1) wn[i]=wn[i-1]*x%mo; dep(i,(mxl>>1)-1,1) wn[i]=wn[i<<1]; } inline ll add(ll x,ll y) {return x+y>=mo?x+y-mo:x+y;} void clear() { rep(i,0,n) a[i]=b[i]=0; } void ntt(ll *a,int f) { if (f>0){ for (int k=n>>1;k;k>>=1) for (int i=0;i<n;i+=k<<1) for (int j=0;j<k;j++){ __int128 x=a[i+j],y=a[i+j+k]; a[i+j+k]=(x-y+mo)*wn[k+j]%mo; a[i+j]=add(x,y); } } else{ for (int k=1;k<n;k<<=1) for (int i=0;i<n;i+=(k<<1)) for (int j=0;j<k;j++){ __int128 x=a[i+j],y=a[i+j+k]*wn[k+j]%mo; a[i+j+k]=add(x,mo-y); a[i+j]=add(x,y); } for (ll i=0,inv=mo-(mo-1)/n;i<n;i++) a[i]=(__int128)a[i]*inv%mo; reverse(a+1,a+n); } } IL void js(ll *A,ll *B,int len) { for (n=1;n<len;n*=2); for (int i=0;i<len;i++) a[i]=(A[i]%mo+mo)%mo,b[i]=(B[i]%mo+mo)%mo; ntt(a,1); ntt(b,1); for (int i=0;i<n;i++) a[i]=(__int128)a[i]*b[i]%mo; ntt(a,-1); for (int i=0;i<len;i++) B[i]=a[i]; clear(); } }F; int main() { F.pre(); return 0; }