多项式全家桶
包括NTT模数和非NTT模数。
如果有锅/可以卡常的地方欢迎评论区指出,会在注释里鸣谢
UPD on 2020/5/24 11:??:加上了多项式快速幂。
UPD on 2020/5/24 19:08:加上了多项式除法并修了快速幂的一个锅。
UPD on 2020/5/30 00:00:修了几个锅。
UPD on 2020/7/15 20:30:修改了数组大小以免溢出,并同时将inv[0]和inv[1]的初始化移到了prep函数里面。
UPD on 2020/7/16 23:15:修改了prep函数,这样可以返回lmt的值,并修改了排版。
UPD on 2021/1/5 20:40:改了若干bug。
UPD on 2021/1/6 22:40:新增 CZT。
UPD on 2021/1/8 21:10:新增 FDT(下降幂多项式乘法),同时修改了 prep 函数,从而预计算出阶乘和阶乘逆元。
UPD on 2021/1/14 13:35:新增 多点求值 并修改了 多项式除法。
UPD on 2021/1/14 16:09:新增了 普通多项式转下降幂多项式 并对原先的 多点求值 进行卡常。
UPD on 2021/1/14 23:30:新增了多项式快速插值。
求评论区提一些有效的建议。
P.S. 由于一些不可抗力(部分缩进是 4 个空格,部分是 tab),直接食用会造成不适,请复制到 tab 长度为 4 的环境下使用。
Code(巨长代码警告)
#ifndef __POLY_H__
#define __POLY_H__
#include<bits/stdc++.h>
#define clear(a) memset((a),0,len<<5)
using namespace std;
typedef long long ll;
const ll N=1048576,P=998244353;
const long double Pi=acos(-1.0);
ll inv[N],fac[N],invfac[N];
namespace Poly{//模数为NTT模数
const ll G=3,img=86583718;
ll lmt,rev[N],a[N],b[N],c[N],d[N],e[N],h[N],x[N],y[N],z[N],X[N],Y[N],ff[N],gg[N],iv[N],t[N];//poly1
ll A[N],B[N],ee[N],Len[N],*p[N],C[N],v[N],*D[N],E[N];//poly2
inline ll qpow(ll a,ll k){
ll ret=1;
while(k){
if(k&1)ret=ret*a%P;
a=(a*a)%P;
k>>=1;
}
return ret%P;
}
inline void init(ll n){
lmt=1;ll t=0;
while(lmt<n)lmt<<=1,t++;
for(ll i=1;i<lmt;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(t-1);
}
inline void NTT(ll *A,ll lmt,ll tp){
for(ll i=0;i<lmt;i++)if(i<rev[i])swap(A[i],A[rev[i]]);
for(ll m=1;m<lmt;m<<=1)
for(ll j=0,Wn=qpow(G,(P-1)/(m<<1));j<lmt;j+=m<<1)
for(ll k=0,w=1,x,y;k<m;k++,w=w*Wn%P)
x=A[j+k],y=w*A[j+k+m]%P,A[j+k]=(x+y)%P,A[j+k+m]=(x-y+P)%P;
if(tp==1)return;
reverse(A+1,A+lmt);
for(ll i=0,inv=qpow(lmt,P-2);i<=lmt;i++)A[i]=A[i]*inv%P;
}
inline void mul(ll *f,ll *g,ll len){
init(len);
NTT(f,lmt,1);NTT(g,lmt,1);
for(ll i=0;i<lmt;i++)f[i]=(f[i]*g[i])%P;
NTT(f,lmt,-1);
}
void getinv(ll*f,ll*g,ll len){
if(len==1){g[0]=qpow(f[0],P-2);return;}
getinv(f,g,len+1>>1);
init(len<<1);
for(ll i=0;i<len;i++)c[i]=f[i];
for(ll i=len;i<lmt;i++)c[i]=0;
NTT(c,lmt,1),NTT(g,lmt,1);
for(ll i=0;i<lmt;i++)g[i]=(2LL-g[i]*c[i]%P+P)%P*g[i]%P;
NTT(g,lmt,-1);
for(ll i=len;i<lmt;i++)g[i]=0;
clear(c);
}
inline void div(ll *f,ll *g,ll *q,ll *r,ll n,ll m){
for(ll i=0,t=n-1;i<n;i++,t--)ff[i]=f[t];
for(ll i=0,t=m-1;i<m;i++,t--)gg[i]=g[t];
ll len=n-m+1;
for(ll i=len;i<n;i++)ff[i]=gg[i]=0;
getinv(gg,iv,len);
mul(ff,iv,len<<1);
for(ll i=0,t=len-1;i<len;i++)q[i]=ff[t--];
for(ll i=len;i<n;i++)q[i]=0;
for(ll i=0;i<n;i++)t[i]=q[i];
len=n;
clear(gg);
for(ll i=0;i<m;i++)gg[i]=g[i];
mul(t,gg,n<<1);
for(ll i=0;i<m-1;i++)r[i]=(f[i]-t[i]+P)%P;
clear(ff),clear(gg),clear(iv),clear(t);
}
inline void getdev(ll*f,ll*g,ll len){
for(ll i=1;i<len;i++)g[i-1]=i*f[i]%P;
g[len-1]=g[len]=0;
}
inline void getinvdev(ll*f,ll*g,ll len){
for(ll i=1;i<=len;i++)g[i]=f[i-1]*inv[i]%P;
g[0]=0;
}
inline void getln(ll*f,ll*g,ll len){
getdev(f,a,len);
getinv(f,b,len);
mul(a,b,len<<1);
getinvdev(a,g,len);
clear(a),clear(b);
}
void getexp(ll*f,ll*g,ll len){
if(len==1){g[0]=1;return;}
getexp(f,g,len+1>>1);
init(len<<1);
for(ll i=0;i<(len<<1);i++)d[i]=e[i]=0;
getln(g,d,len);
for(ll i=0;i<len;i++)e[i]=f[i];
NTT(g,lmt,1),NTT(d,lmt,1),NTT(e,lmt,1);
for(ll i=0;i<lmt;i++)g[i]=(1-d[i]+e[i]+P)*g[i]%P;
NTT(g,lmt,-1);
for(ll i=len;i<lmt;i++)g[i]=0;
clear(d),clear(e);
}
void getpow(ll*f,ll*g,ll len,ll k){
getln(f,h,len);
for(ll i=0;i<len;i++)h[i]=h[i]*k%P;
getexp(h,g,len);
clear(h);
}
inline void getsqrt(ll*f,ll*g,ll len){
getln(f,h,len);
for(ll i=0;i<len;i++)h[i]=h[i]*inv[2]%P;
getexp(h,g,len);
clear(h);
}
void sin(ll*f,ll*g,ll len){
for(ll i=0;i<len;i++)x[i]=img*f[i]%P;
getexp(x,X,len),getinv(X,Y,len);
for(ll i=0;i<len;i++)g[i]=(X[i]-Y[i]+P)%P*qpow(img<<1,P-2)%P;
clear(x),clear(X),clear(Y);
}
void cos(ll*f,ll*g,ll len){
for(ll i=0;i<len;i++)x[i]=img*f[i]%P;
getexp(x,X,len),getinv(X,Y,len);
for(ll i=0;i<len;i++)g[i]=(X[i]+Y[i])%P*inv[2]%P;
clear(x),clear(X),clear(Y);
}
inline void arcsin(ll*f,ll*g,ll len){
getdev(f,x,len);
init(len<<1);
NTT(f,lmt,1);
for(ll i=0;i<lmt;i++)y[i]=(1+P-f[i]*f[i]%P)%P;
NTT(y,lmt,-1);
for(ll i=len;i<lmt;i++)y[i]=0;
getsqrt(y,z,len);
memset(y,0,(len+1)<<3);
getinv(z,y,len);
NTT(x,lmt,1),NTT(y,lmt,1);
for(ll i=0;i<lmt;i++)x[i]=x[i]*y[i]%P;
NTT(x,lmt,-1);
getinvdev(x,g,len);
clear(x),clear(y),clear(z);
}
inline void arctan(ll*f,ll*g,ll len){
getdev(f,x,len);
init(len<<1);
NTT(f,lmt,1);
for(ll i=0;i<lmt;i++)y[i]=(1+f[i]*f[i]%P)%P;
NTT(y,lmt,-1);
for(ll i=len;i<lmt;i++)y[i]=0;
getinv(y,z,len);
NTT(x,lmt,1),NTT(z,lmt,1);
for(ll i=0;i<lmt;i++)x[i]=x[i]*z[i]%P;
NTT(x,lmt,-1);
getinvdev(x,g,len);
clear(x),clear(y),clear(z);
}
inline ll F(ll x){return x*(x-1)/2%(P-1);}
inline void CZT(ll *f,ll *g,ll len,ll c,ll m){
for(ll i=0;i<len;i++)A[i]=qpow(c,P-1-F(i))*f[i]%P;
for(ll i=0;i<len+m;i++)B[i]=qpow(c,F(i));
reverse(A,A+len);
mul(A,B,len*2+m);
for(ll i=0;i<m;i++)g[i]=qpow(c,P-1-F(i))*A[i+len-1]%P;
clear(A),clear(B);
}
void FDT(ll *A,ll len,ll tp){
init(len<<1);
if(tp==-1)for(ll i=0;i<lmt;i++)A[i]=A[i]*invfac[i]%P;
for(ll i=0;i<len;i++){
if(tp==-1&&i&1)ee[i]=P-invfac[i];
else ee[i]=invfac[i];
}
for(ll i=len;i<lmt;i++)ee[i]=A[i]=0;
NTT(A,lmt,1);NTT(ee,lmt,1);
for(ll i=0;i<lmt;i++)A[i]=A[i]*ee[i]%P;
NTT(A,lmt,-1);
if(tp==1)for(ll i=0;i<lmt;i++)A[i]=A[i]*fac[i]%P;
for(ll i=0;i<lmt;i++)ee[i]=0;
}
inline void mulDown(ll *f,ll *g,ll len){
FDT(f,len,1);FDT(g,len,1);
for(ll i=0;i<len;i++)f[i]=f[i]*g[i]%P;
FDT(f,len,-1);
}
void getP(const ll *a,ll k,ll l,ll r){
if(l==r){
Len[k]=1;
p[k]=new ll[2];
p[k][0]=P-a[l];
p[k][1]=1;
return;
}
ll mid=l+r>>1;
getP(a,k<<1,l,mid);
getP(a,k<<1|1,mid+1,r);
Len[k]=Len[k<<1]+Len[k<<1|1];
p[k]=new ll[Len[k]+1];
init(Len[k]+1<<1);
static ll A[N],B[N];
for(ll i=0;i<=Len[k<<1];i++)A[i]=p[k<<1][i];
for(ll i=Len[k<<1]+1;i<lmt;i++)A[i]=0;
for(ll i=0;i<=Len[k<<1|1];i++)B[i]=p[k<<1|1][i];
for(ll i=Len[k<<1|1]+1;i<lmt;i++)B[i]=0;
NTT(A,lmt,1);NTT(B,lmt,1);
for(ll i=0;i<lmt;i++)A[i]=A[i]*B[i]%P;
NTT(A,lmt,-1);
for(ll i=0;i<=Len[k];i++)p[k][i]=A[i];
}
void solve(ll k,ll l,ll r,const ll *a,ll *A,ll *ans){
if(Len[k]<=500){
ll m=Len[k]-1;
for(ll i=l;i<=r;i++)
for(ll j=m;j>=0;j--)
ans[i]=(ans[i]*a[i]+A[j])%P;
return;
}
if(l==r){ans[l]=*A;return;}
ll mid=l+r>>1,R[Len[k]+2>>1];
static ll t[N];
div(A,p[k<<1],t,R,Len[k],Len[k<<1]+1);
solve(k<<1,l,mid,a,R,ans);
div(A,p[k<<1|1],t,R,Len[k],Len[k<<1|1]+1);
solve(k<<1|1,mid+1,r,a,R,ans);
}
void evaluation(ll *f,ll *a,ll *ans,ll n,ll m){
getP(a,1,1,m);
if(n>m){
static ll t[N];
div(f,p[1],t,f,n,m+1);
}
solve(1,1,m,a,f,ans);
}
void solve(ll k,ll l,ll r,const ll *x){
if(l==r){
D[k]=new ll[1];
D[k][0]=v[l];
return;
}
ll mid=l+r>>1;
solve(k<<1,l,mid,x);
solve(k<<1|1,mid+1,r,x);
D[k]=new ll[Len[k]];
init(Len[k]);
static ll f1[N],f2[N],p1[N],p2[N];
for(ll i=0;i<Len[k<<1];i++)f1[i]=D[k<<1][i];
for(ll i=Len[k<<1];i<lmt;i++)f1[i]=0;
for(ll i=0;i<Len[k<<1|1];i++)f2[i]=D[k<<1|1][i];
for(ll i=Len[k<<1|1];i<lmt;i++)f2[i]=0;
for(ll i=0;i<=Len[k<<1];i++)p1[i]=p[k<<1][i];
for(ll i=Len[k<<1]+1;i<lmt;i++)p1[i]=0;
for(ll i=0;i<=Len[k<<1|1];i++)p2[i]=p[k<<1|1][i];
for(ll i=Len[k<<1|1]+1;i<lmt;i++)p2[i]=0;
mul(f1,p2,Len[k]);
mul(f2,p1,Len[k]);
for(ll i=0;i<Len[k];i++)D[k][i]=(f1[i]+f2[i])%P;
}
void interpolation(ll *x,ll *y,ll *f,ll n){
ll len=n;
getP(x,1,1,n);
getdev(p[1],C,n+1);
solve(1,1,n,x,C,v);
for(ll i=1;i<=n;i++)v[i]=y[i]*qpow(v[i],P-2)%P;
solve(1,1,n,x);
for(ll i=0;i<n;i++)f[i]=D[1][i];
clear(v);
}
void polytoffp(ll *f,ll *g,ll len){
for(ll i=1;i<=len;i++)E[i]=i-1;
clear(g);
evaluation(f,E,g,len,len);
for(ll i=0;i<len;i++)g[i]=g[i+1]*invfac[i],E[i]=(i&1?P-invfac[i]:invfac[i]);
E[len]=g[len]=0;
mul(g,E,len<<1);
clear(E);
}
}
ll prep(ll n){
ll lmt=1;
inv[0]=inv[1]=1;
while(lmt<n)lmt<<=1;
for(ll i=2;i<lmt;i++)inv[i]=(P-P/i)*inv[P%i]%P;
fac[0]=invfac[0]=1;
for(ll i=1;i<lmt;i++)fac[i]=fac[i-1]*i%P,invfac[i]=invfac[i-1]*inv[i]%P;
return lmt;
}
namespace Poly2{//模数不是NTT模数
int lmt,rev[N];
struct comp{
long double x,y;
comp(long double a=0,long double b=0){x=a,y=b;}
}a[N],b[N],c[N],d[N];
comp operator+(comp a,comp b){return comp(a.x+b.x,a.y+b.y);}
comp operator-(comp a,comp b){return comp(a.x-b.x,a.y-b.y);}
comp operator*(comp a,comp b){return comp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
comp operator/(comp a,int t){return comp(a.x/t,a.y/t);}
inline void init(int n){
lmt=1;int t=0;
while(lmt<n)lmt<<=1,t++;
for(int i=1;i<lmt;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(t-1);
}
inline void FFT(comp*A,int lmt,int tp){
for(int i=0;i<lmt;i++)if(i<rev[i])swap(A[i],A[rev[i]]);
for(int mid=1;mid<lmt;mid<<=1){
comp Wn(cos(Pi/mid),tp*sin(Pi/mid));
for(int R=mid<<1,j=0;j<lmt;j+=R){
comp w(1,0);
for(int k=0;k<mid;k++,w=w*Wn){
comp x=A[j+k],y=w*A[j+k+mid];
A[j+k]=x+y,A[j+k+mid]=x-y;
}
}
}
}
void MTT(int*f,int*g,int*ans,int n,int m){
init(n+m);
const int lim=(1<<15)-1;
for(int i=0;i<n;i++)a[i]=comp(f[i]&lim,f[i]>>15);
for(int i=n;i<lmt;i++)a[i]=comp();
for(int i=0;i<m;i++)b[i]=comp(g[i]&lim,g[i]>>15);
for(int i=m;i<lmt;i++)b[i]=comp();
FFT(a,lmt,1),FFT(b,lmt,1);
for(int i=0;i<lmt;i++){
int t=(lmt-i)&(lmt-1);
c[i]=comp((a[i].x+a[t].x)*0.5,(a[i].y-a[t].y)*0.5)*b[i];
d[i]=comp((a[i].y+a[t].y)*0.5,(a[t].x-a[i].x)*0.5)*b[i];
}
FFT(c,lmt,-1),FFT(d,lmt,-1);
for(int i=0;i<lmt;i++)c[i]=c[i]/lmt,d[i]=d[i]/lmt;
for(int i=0;i<lmt;i++){
ll p=c[i].x+0.5,o=c[i].y+0.5,x=d[i].x+0.5,u=d[i].y+0.5;
ans[i]=(p%P+((o+x)%P<<15)+(u%P<<30))%P;
}
}
}
#endif