多项式全家桶

包括NTT模数和非NTT模数。

如果有锅/可以卡常的地方欢迎评论区指出,会在注释里鸣谢

UPD on 2020/5/24 11:??:加上了多项式快速幂。

UPD on 2020/5/24 19:08:加上了多项式除法并修了快速幂的一个锅。

UPD on 2020/5/30 00:00:修了几个锅。

UPD on 2020/7/15 20:30:修改了数组大小以免溢出,并同时将inv[0]和inv[1]的初始化移到了prep函数里面。

UPD on 2020/7/16 23:15:修改了prep函数,这样可以返回lmt的值,并修改了排版。

UPD on 2021/1/5 20:40:改了若干bug。

UPD on 2021/1/6 22:40:新增 CZT。

UPD on 2021/1/8 21:10:新增 FDT(下降幂多项式乘法),同时修改了 prep 函数,从而预计算出阶乘和阶乘逆元。

UPD on 2021/1/14 13:35:新增 多点求值 并修改了 多项式除法。

UPD on 2021/1/14 16:09:新增了 普通多项式转下降幂多项式 并对原先的 多点求值 进行卡常。

UPD on 2021/1/14 23:30:新增了多项式快速插值。

求评论区提一些有效的建议。

P.S. 由于一些不可抗力(部分缩进是 4 个空格,部分是 tab),直接食用会造成不适,请复制到 tab 长度为 4 的环境下使用。

Code(巨长代码警告
#ifndef __POLY_H__
#define __POLY_H__
#include<bits/stdc++.h>
#define clear(a) memset((a),0,len<<5)
using namespace std;
typedef long long ll;
const ll N=1048576,P=998244353;
const long double Pi=acos(-1.0);
ll inv[N],fac[N],invfac[N];
namespace Poly{//模数为NTT模数 
    const ll G=3,img=86583718;
    ll lmt,rev[N],a[N],b[N],c[N],d[N],e[N],h[N],x[N],y[N],z[N],X[N],Y[N],ff[N],gg[N],iv[N],t[N];//poly1
	ll A[N],B[N],ee[N],Len[N],*p[N],C[N],v[N],*D[N],E[N];//poly2
    inline ll qpow(ll a,ll k){
        ll ret=1;
        while(k){
            if(k&1)ret=ret*a%P;
            a=(a*a)%P;
            k>>=1;
        }
        return ret%P;
    }
    inline void init(ll n){
        lmt=1;ll t=0;
        while(lmt<n)lmt<<=1,t++;
        for(ll i=1;i<lmt;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(t-1);
    }
    inline void NTT(ll *A,ll lmt,ll tp){
        for(ll i=0;i<lmt;i++)if(i<rev[i])swap(A[i],A[rev[i]]);
        for(ll m=1;m<lmt;m<<=1)
            for(ll j=0,Wn=qpow(G,(P-1)/(m<<1));j<lmt;j+=m<<1)
                for(ll k=0,w=1,x,y;k<m;k++,w=w*Wn%P)
                    x=A[j+k],y=w*A[j+k+m]%P,A[j+k]=(x+y)%P,A[j+k+m]=(x-y+P)%P;
        if(tp==1)return;
        reverse(A+1,A+lmt);
        for(ll i=0,inv=qpow(lmt,P-2);i<=lmt;i++)A[i]=A[i]*inv%P;
    } 
    inline void mul(ll *f,ll *g,ll len){
        init(len);
        NTT(f,lmt,1);NTT(g,lmt,1);
        for(ll i=0;i<lmt;i++)f[i]=(f[i]*g[i])%P;
        NTT(f,lmt,-1);
    } 
    void getinv(ll*f,ll*g,ll len){
        if(len==1){g[0]=qpow(f[0],P-2);return;}
        getinv(f,g,len+1>>1);
        init(len<<1);
        for(ll i=0;i<len;i++)c[i]=f[i];
        for(ll i=len;i<lmt;i++)c[i]=0;
        NTT(c,lmt,1),NTT(g,lmt,1);
        for(ll i=0;i<lmt;i++)g[i]=(2LL-g[i]*c[i]%P+P)%P*g[i]%P;
        NTT(g,lmt,-1);
        for(ll i=len;i<lmt;i++)g[i]=0; 
        clear(c);
    }
    inline void div(ll *f,ll *g,ll *q,ll *r,ll n,ll m){
        for(ll i=0,t=n-1;i<n;i++,t--)ff[i]=f[t];
        for(ll i=0,t=m-1;i<m;i++,t--)gg[i]=g[t];
        ll len=n-m+1;
        for(ll i=len;i<n;i++)ff[i]=gg[i]=0;
        getinv(gg,iv,len);
        mul(ff,iv,len<<1);
        for(ll i=0,t=len-1;i<len;i++)q[i]=ff[t--];
        for(ll i=len;i<n;i++)q[i]=0;
        for(ll i=0;i<n;i++)t[i]=q[i];
        len=n;
        clear(gg);
        for(ll i=0;i<m;i++)gg[i]=g[i];
        mul(t,gg,n<<1);
        for(ll i=0;i<m-1;i++)r[i]=(f[i]-t[i]+P)%P;
        clear(ff),clear(gg),clear(iv),clear(t);
    }
    inline void getdev(ll*f,ll*g,ll len){
        for(ll i=1;i<len;i++)g[i-1]=i*f[i]%P;
        g[len-1]=g[len]=0;
    }
    inline void getinvdev(ll*f,ll*g,ll len){
        for(ll i=1;i<=len;i++)g[i]=f[i-1]*inv[i]%P;
        g[0]=0;
    }
    inline void getln(ll*f,ll*g,ll len){
        getdev(f,a,len);
        getinv(f,b,len);
        mul(a,b,len<<1);
        getinvdev(a,g,len);
        clear(a),clear(b);
    }
    void getexp(ll*f,ll*g,ll len){
        if(len==1){g[0]=1;return;}
        getexp(f,g,len+1>>1);
        init(len<<1);
        for(ll i=0;i<(len<<1);i++)d[i]=e[i]=0;
        getln(g,d,len);
        for(ll i=0;i<len;i++)e[i]=f[i];
        NTT(g,lmt,1),NTT(d,lmt,1),NTT(e,lmt,1);
        for(ll i=0;i<lmt;i++)g[i]=(1-d[i]+e[i]+P)*g[i]%P;
        NTT(g,lmt,-1);
        for(ll i=len;i<lmt;i++)g[i]=0; 
        clear(d),clear(e);
    }
	void getpow(ll*f,ll*g,ll len,ll k){
        getln(f,h,len);
        for(ll i=0;i<len;i++)h[i]=h[i]*k%P;
        getexp(h,g,len);
        clear(h);
    }
    inline void getsqrt(ll*f,ll*g,ll len){
        getln(f,h,len);
        for(ll i=0;i<len;i++)h[i]=h[i]*inv[2]%P;
        getexp(h,g,len);
        clear(h);
    }
    void sin(ll*f,ll*g,ll len){
        for(ll i=0;i<len;i++)x[i]=img*f[i]%P;
        getexp(x,X,len),getinv(X,Y,len);
        for(ll i=0;i<len;i++)g[i]=(X[i]-Y[i]+P)%P*qpow(img<<1,P-2)%P;
        clear(x),clear(X),clear(Y);
    }
    void cos(ll*f,ll*g,ll len){
        for(ll i=0;i<len;i++)x[i]=img*f[i]%P;
        getexp(x,X,len),getinv(X,Y,len);
        for(ll i=0;i<len;i++)g[i]=(X[i]+Y[i])%P*inv[2]%P;
        clear(x),clear(X),clear(Y);
    } 
    inline void arcsin(ll*f,ll*g,ll len){
        getdev(f,x,len);
        init(len<<1);
        NTT(f,lmt,1);
        for(ll i=0;i<lmt;i++)y[i]=(1+P-f[i]*f[i]%P)%P;
        NTT(y,lmt,-1);
        for(ll i=len;i<lmt;i++)y[i]=0;
        getsqrt(y,z,len);
        memset(y,0,(len+1)<<3);
        getinv(z,y,len);
        NTT(x,lmt,1),NTT(y,lmt,1);
        for(ll i=0;i<lmt;i++)x[i]=x[i]*y[i]%P;
        NTT(x,lmt,-1);
        getinvdev(x,g,len);
        clear(x),clear(y),clear(z);
    }
    inline void arctan(ll*f,ll*g,ll len){
        getdev(f,x,len);
        init(len<<1);
        NTT(f,lmt,1);
        for(ll i=0;i<lmt;i++)y[i]=(1+f[i]*f[i]%P)%P;
        NTT(y,lmt,-1);
        for(ll i=len;i<lmt;i++)y[i]=0;
        getinv(y,z,len);
        NTT(x,lmt,1),NTT(z,lmt,1);
        for(ll i=0;i<lmt;i++)x[i]=x[i]*z[i]%P;
        NTT(x,lmt,-1);
        getinvdev(x,g,len);
        clear(x),clear(y),clear(z);
    }
    inline ll F(ll x){return x*(x-1)/2%(P-1);}
	inline void CZT(ll *f,ll *g,ll len,ll c,ll m){
    	for(ll i=0;i<len;i++)A[i]=qpow(c,P-1-F(i))*f[i]%P;
		for(ll i=0;i<len+m;i++)B[i]=qpow(c,F(i));
    	reverse(A,A+len);
    	mul(A,B,len*2+m);
        for(ll i=0;i<m;i++)g[i]=qpow(c,P-1-F(i))*A[i+len-1]%P;
    	clear(A),clear(B);
    }
    void FDT(ll *A,ll len,ll tp){
    	init(len<<1);
    	if(tp==-1)for(ll i=0;i<lmt;i++)A[i]=A[i]*invfac[i]%P;
    	for(ll i=0;i<len;i++){
    		if(tp==-1&&i&1)ee[i]=P-invfac[i];
    		else ee[i]=invfac[i];
		}
		for(ll i=len;i<lmt;i++)ee[i]=A[i]=0;
		NTT(A,lmt,1);NTT(ee,lmt,1);
		for(ll i=0;i<lmt;i++)A[i]=A[i]*ee[i]%P;
		NTT(A,lmt,-1);
		if(tp==1)for(ll i=0;i<lmt;i++)A[i]=A[i]*fac[i]%P;
		for(ll i=0;i<lmt;i++)ee[i]=0;
	}
    inline void mulDown(ll *f,ll *g,ll len){
    	FDT(f,len,1);FDT(g,len,1);
    	for(ll i=0;i<len;i++)f[i]=f[i]*g[i]%P;
    	FDT(f,len,-1);
	}
	void getP(const ll *a,ll k,ll l,ll r){
    	if(l==r){
    		Len[k]=1;
    		p[k]=new ll[2];
    		p[k][0]=P-a[l];
    		p[k][1]=1;
    		return;
		}
		ll mid=l+r>>1;
		getP(a,k<<1,l,mid);
		getP(a,k<<1|1,mid+1,r);
		Len[k]=Len[k<<1]+Len[k<<1|1];
		p[k]=new ll[Len[k]+1];
		init(Len[k]+1<<1);
		static ll A[N],B[N];
		for(ll i=0;i<=Len[k<<1];i++)A[i]=p[k<<1][i];
		for(ll i=Len[k<<1]+1;i<lmt;i++)A[i]=0;
		for(ll i=0;i<=Len[k<<1|1];i++)B[i]=p[k<<1|1][i];
		for(ll i=Len[k<<1|1]+1;i<lmt;i++)B[i]=0;
		NTT(A,lmt,1);NTT(B,lmt,1);
		for(ll i=0;i<lmt;i++)A[i]=A[i]*B[i]%P;
		NTT(A,lmt,-1);
		for(ll i=0;i<=Len[k];i++)p[k][i]=A[i];
	}
	void solve(ll k,ll l,ll r,const ll *a,ll *A,ll *ans){
		if(Len[k]<=500){
			ll m=Len[k]-1;
			for(ll i=l;i<=r;i++)
				for(ll j=m;j>=0;j--)
					ans[i]=(ans[i]*a[i]+A[j])%P;
			return;
		}
		if(l==r){ans[l]=*A;return;}
		ll mid=l+r>>1,R[Len[k]+2>>1];
		static ll t[N];
		div(A,p[k<<1],t,R,Len[k],Len[k<<1]+1);
		solve(k<<1,l,mid,a,R,ans);
		div(A,p[k<<1|1],t,R,Len[k],Len[k<<1|1]+1); 
		solve(k<<1|1,mid+1,r,a,R,ans);
	}
	void evaluation(ll *f,ll *a,ll *ans,ll n,ll m){
		getP(a,1,1,m); 
		if(n>m){
			static ll t[N];
			div(f,p[1],t,f,n,m+1);
		}
		solve(1,1,m,a,f,ans);
	}
	void solve(ll k,ll l,ll r,const ll *x){
		if(l==r){
			D[k]=new ll[1];
			D[k][0]=v[l];
			return;
		}
		ll mid=l+r>>1;
		solve(k<<1,l,mid,x);
		solve(k<<1|1,mid+1,r,x);
		D[k]=new ll[Len[k]];
		init(Len[k]);
		static ll f1[N],f2[N],p1[N],p2[N];
		for(ll i=0;i<Len[k<<1];i++)f1[i]=D[k<<1][i];
		for(ll i=Len[k<<1];i<lmt;i++)f1[i]=0;
		for(ll i=0;i<Len[k<<1|1];i++)f2[i]=D[k<<1|1][i];
		for(ll i=Len[k<<1|1];i<lmt;i++)f2[i]=0;
		for(ll i=0;i<=Len[k<<1];i++)p1[i]=p[k<<1][i];
		for(ll i=Len[k<<1]+1;i<lmt;i++)p1[i]=0;
		for(ll i=0;i<=Len[k<<1|1];i++)p2[i]=p[k<<1|1][i];
		for(ll i=Len[k<<1|1]+1;i<lmt;i++)p2[i]=0;
		mul(f1,p2,Len[k]);
		mul(f2,p1,Len[k]);
		for(ll i=0;i<Len[k];i++)D[k][i]=(f1[i]+f2[i])%P;
	}
	void interpolation(ll *x,ll *y,ll *f,ll n){
		ll len=n;
		getP(x,1,1,n);
		getdev(p[1],C,n+1);
		solve(1,1,n,x,C,v);
		for(ll i=1;i<=n;i++)v[i]=y[i]*qpow(v[i],P-2)%P;
		solve(1,1,n,x);
		for(ll i=0;i<n;i++)f[i]=D[1][i];
		clear(v);
	}
	void polytoffp(ll *f,ll *g,ll len){
		for(ll i=1;i<=len;i++)E[i]=i-1;
		clear(g);
		evaluation(f,E,g,len,len);
		for(ll i=0;i<len;i++)g[i]=g[i+1]*invfac[i],E[i]=(i&1?P-invfac[i]:invfac[i]);
		E[len]=g[len]=0;
		mul(g,E,len<<1);
		clear(E);
	}
}
ll prep(ll n){
    ll lmt=1;
    inv[0]=inv[1]=1;
    while(lmt<n)lmt<<=1;
    for(ll i=2;i<lmt;i++)inv[i]=(P-P/i)*inv[P%i]%P;
    fac[0]=invfac[0]=1;
    for(ll i=1;i<lmt;i++)fac[i]=fac[i-1]*i%P,invfac[i]=invfac[i-1]*inv[i]%P;
    return lmt;
}
namespace Poly2{//模数不是NTT模数 
    int lmt,rev[N];
    struct comp{
        long double x,y;
        comp(long double a=0,long double b=0){x=a,y=b;}
    }a[N],b[N],c[N],d[N];
    comp operator+(comp a,comp b){return comp(a.x+b.x,a.y+b.y);}
    comp operator-(comp a,comp b){return comp(a.x-b.x,a.y-b.y);}
    comp operator*(comp a,comp b){return comp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
    comp operator/(comp a,int t){return comp(a.x/t,a.y/t);}
    inline void init(int n){
        lmt=1;int t=0;
        while(lmt<n)lmt<<=1,t++;
        for(int i=1;i<lmt;i++)rev[i]=(rev[i>>1]>>1)|(i&1)<<(t-1);
    }
    inline void FFT(comp*A,int lmt,int tp){
        for(int i=0;i<lmt;i++)if(i<rev[i])swap(A[i],A[rev[i]]);
        for(int mid=1;mid<lmt;mid<<=1){
            comp Wn(cos(Pi/mid),tp*sin(Pi/mid));
            for(int R=mid<<1,j=0;j<lmt;j+=R){
                comp w(1,0);
                for(int k=0;k<mid;k++,w=w*Wn){
                    comp x=A[j+k],y=w*A[j+k+mid];
                    A[j+k]=x+y,A[j+k+mid]=x-y;
                }
            }
        }
    }
    void MTT(int*f,int*g,int*ans,int n,int m){
        init(n+m);
        const int lim=(1<<15)-1;
        for(int i=0;i<n;i++)a[i]=comp(f[i]&lim,f[i]>>15);
        for(int i=n;i<lmt;i++)a[i]=comp();
        for(int i=0;i<m;i++)b[i]=comp(g[i]&lim,g[i]>>15);
        for(int i=m;i<lmt;i++)b[i]=comp();
        FFT(a,lmt,1),FFT(b,lmt,1);
        for(int i=0;i<lmt;i++){
            int t=(lmt-i)&(lmt-1);
            c[i]=comp((a[i].x+a[t].x)*0.5,(a[i].y-a[t].y)*0.5)*b[i];
            d[i]=comp((a[i].y+a[t].y)*0.5,(a[t].x-a[i].x)*0.5)*b[i];
        }
        FFT(c,lmt,-1),FFT(d,lmt,-1);
        for(int i=0;i<lmt;i++)c[i]=c[i]/lmt,d[i]=d[i]/lmt;
        for(int i=0;i<lmt;i++){
            ll p=c[i].x+0.5,o=c[i].y+0.5,x=d[i].x+0.5,u=d[i].y+0.5;
            ans[i]=(p%P+((o+x)%P<<15)+(u%P<<30))%P;
        }
    }
}
#endif
posted @ 2021-01-14 16:13  happydef  阅读(112)  评论(0编辑  收藏  举报