多项式运算封装

update on 21.12.30:添加了 polyeva;修补了 polymod 处多测时可能产生的 bug。

update on 22.2.7:重 写(前一版太丑了),改为完全封装版本(使用 std::vector<int> 存放多项式系数,运算在命名空间 polynomial:: 里)

update on 22.2.11 第二版增加了快速求值插值(但是使用的是利用多项式取模进行降次的做法,常数巨大,关于转置原理的做法见作者在 Luogu 5050 的提交记录),


实现的并不优秀,但应该很稳(雾)。全部提交测试过,部分由 Vincra 这个账号提交。

索引:

Ver 2(施工中)

使用 poly A; 来声明一个名为 \(A\) 的多项式,其原型为 vector<int> a;,较 Ver 1 更为优美、精简,去掉了没啥用、并且很丑的常数项不为 \(1\) 时的多项式快速幂。

下面是 P5158 的代码:

const int djq=998244353;
inline int ksm(int base,int p){
	int ret=1;
	while(p){
		if(p&1) ret=1ll*ret*base%djq;
		base=1ll*base*base%djq,p>>=1;
	}
	return ret;
}
namespace polynomial{
    typedef vector<int> poly;
    const int G=3,invG=ksm(G,djq-2),inv2=ksm(2,djq-2);
    int rev[1200005];
    inline int BSGS(int a,int b){
        map<int,int> mp; b%=djq;
        int t=ceil(sqrt((double)djq)),epow=1;
        for(rg int j=0;j<t;++j,epow=1ll*epow*a%djq) mp[1ll*b*epow%djq]=j;
        a=epow,epow=1;
        for(rg int i=0;i<=t;++i,epow=1ll*epow*a%djq) if(mp.find(epow)!=mp.end()&&1ll*i*t-mp[epow]>=0) return 1ll*i*t-mp[epow];
        return b;
    }
    inline int modsqrt(int x){ return (!x||x==1)?x:ksm(3,BSGS(3,x)>>1); }
    inline int initrev(const int n){
        int len=1,lgn=0;
        while(len<=n) len<<=1,++lgn;
        for(rg int i=0;i<len;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lgn-1));
        return len;
    }
    inline void NTT(poly &A,const int len,const int opt){
        A.resize(len);
        for(rg int i=0;i<len;++i) if(i<rev[i]) swap(A[i],A[rev[i]]);
        for(rg int i=1;i<len;i<<=1){
            const int g=ksm(opt?G:invG,(djq-1)/(i<<1));
            for(rg int j=0,mid=(i<<1);j<len;j+=mid){
            	int gn=1;
                for(rg int k=0;k<i;++k,gn=1ll*gn*g%djq){
                    const int x=A[j+k],y=1ll*gn*A[i+j+k]%djq;
                    A[j+k]=(x+y)%djq,A[i+j+k]=(x-y+djq)%djq;
                }
            }
        }
        const int invlen=ksm(len,djq-2);
        if(!opt) for(rg int i=0;i<len;++i) A[i]=1ll*A[i]*invlen%djq;
    }
    poly mul(poly A,poly B){
        const int n=A.size()+B.size()-1,len=initrev(n);
        poly C; C.resize(len);
        NTT(A,len,1),NTT(B,len,1);
        for(rg int i=0;i<len;++i) C[i]=1ll*A[i]*B[i]%djq;
        NTT(C,len,0),C.resize(n);
        return C;
    }
    poly inv(poly F,int n){
        if(n==1) return poly(1,ksm(F[0],djq-2));
        poly A(F.begin(),F.begin()+n);
        poly B=inv(F,n+1>>1);
        int len=initrev(n<<1);
        NTT(A,len,1),NTT(B,len,1);
        for(rg int i=0;i<len;++i) A[i]=1ll*B[i]*(2-1ll*A[i]*B[i]%djq+djq)%djq;
        return NTT(A,len,0),A.resize(n),A;
    }
    poly mod(poly F,poly G){
    	poly RF=F,RG=G,H; const int n=F.size(),m=G.size();
		reverse(RF.begin(),RF.end()),reverse(RG.begin(),RG.end());
		RF.resize(n-m+1),RG.resize(n-m+1);
		H=mul(RF,inv(RG,n-m+1)),H.resize(n-m+1),reverse(H.begin(),H.end());
		H=mul(H,G),H.resize(m-1);
		for(rg int i=0;i<m-1;++i) H[i]=(F[i]-H[i]+djq)%djq;
		return H;
    }
    poly dif(poly A){
        const int n=A.size();
        for(rg int i=1;i<n;++i) A[i-1]=1ll*A[i]*i%djq;
        return A.resize(n-1),A;
    }
    poly idif(poly A){
        const int n=A.size(); A.resize(n+1);
        for(int i=n;i;--i) A[i]=1ll*A[i-1]*ksm(i,djq-2)%djq;
        return A[0]=0,A;
    }
    poly ln(poly A,int n){
        poly B=idif(mul(dif(A),inv(A,n)));
        return B.resize(n),B;
    }
    poly exp(poly F,int n) {
        if(n==1) return poly(1,1);
        poly A=exp(F,n+1>>1); A.resize(n);
        poly B=ln(A,n);
        for(rg int i=0;i<n;++i) B[i]=(F[i]-B[i]+djq)%djq;
        const int len=initrev(n<<1);
        NTT(A,len,1),NTT(B,len,1);
        for(rg int i=0;i<len;++i) A[i]=1ll*A[i]*(1+B[i])%djq;
        return NTT(A,len,0),A.resize(n),A;
    }
    poly sqrt(poly F, int n) {
        if(n==1) return poly(1,modsqrt(F[0]));
        poly A(F.begin(),F.begin()+n);
        poly B=sqrt(F,n+1>>1),C=inv(B,n);
        const int len=initrev(n<<1);
        NTT(A,len,1),NTT(C,len,1);
        for(rg int i=0;i<len;++i) A[i]=1ll*A[i]*C[i]%djq;
        NTT(A,len,0);
        for(rg int i=0;i<n;++i) A[i]=1ll*(A[i]+B[i])*inv2%djq;
        return A.resize(n),A;
    }
    poly qpow(poly A,const int k,const int n){
        A.resize(n); A=ln(A,n); 
        for(rg int i=0;i<n;++i) A[i]=1ll*A[i]*k%djq;
        return exp(A,n);
    }
}
using namespace polynomial;
int ex[100005],ey[100005];
poly p[100005<<2];
void evabuild(int x,int l,int r){
	if(l==r) return p[x].resize(2),p[x][0]=(djq-ex[l])%djq,p[x][1]=1,void();
	const int mid=l+r>>1;
	evabuild(x<<1,l,mid),evabuild(x<<1|1,mid+1,r),p[x]=mul(p[x<<1],p[x<<1|1]);
}
void evasolve(int x,int l,int r,poly A){
	if(l==r) return ey[l]=A[0],void();
	const int mid=l+r>>1;
	poly tmp=mod(A,p[x<<1]); evasolve(x<<1,l,mid,tmp);
	tmp=mod(A,p[x<<1|1]); evasolve(x<<1|1,mid+1,r,tmp);
}
poly g;
int n,x[100005],y[100005];
poly intsolve(int x,int l,int r){
	if(l==r) return poly(1,1ll*y[l]*ksm(ey[l],djq-2)%djq);
	const int mid=l+r>>1;
	poly tmp1=mul(intsolve(x<<1,l,mid),p[x<<1|1]),tmp2=mul(intsolve(x<<1|1,mid+1,r),p[x<<1]);
	tmp1.resize(r-l+1),tmp2.resize(r-l+1);
	for(rg int i=0;i<r-l+1;++i) tmp1[i]=(tmp1[i]+tmp2[i])%djq;
	return tmp1;
}
int main(){
	n=read();
	for(rg int i=1;i<=n;++i) ex[i]=x[i]=read(),y[i]=read();
	evabuild(1,1,n); g=dif(p[1]); evasolve(1,1,n,g); poly ans=intsolve(1,1,n);
	for(rg int i=0;i<n;++i) printf("%d ",ans[i]);
    return 0;
}

Ver 1

变量/常量:
  1. N:多项式次数。
  2. djq G invG :模数、该模数的一个原根、该原根的逆元。
  3. ar1~ar5:不知道起什么名字的辅助数组。
  4. ivG~imG:看名字大概能猜出来什么用的辅助数组。
  5. rev1 rev2:蝴蝶变换数组。
函数( pulic ):

若无特殊说明,整数、多项式系数均为 \(\pmod {p}\) 意义下。

  1. ksm(a,b):求 \(a^b\);快速幂;\(O(\log b)\)

  2. modsqrt(a):求 \(a^{\frac{1}{2}}\);原根 + BSGS;\(O(\sqrt {p})\)

  3. initrev(n,lgn,l,rev):预处理\(\pmod{x^n}\) 意义下 NTT 所用的蝴蝶变换数组;暴力;\(O(n)\)

  4. NTT(A,n,opt,rev):以 \(rev\) 作为蝴蝶变换数组,对 \(A(x)\pmod{x^n}\) 做快速数论变换(\(n^{-1}\) 放在外面乘);NTT;\(O(n\log n)\)

  5. polymul(A,B,C,n):求 \(C(x)=A(x)B(x)\pmod{x^{2n}}\);NTT;\(O(n\log n)\)

  6. polyinv(A,B,n):求 \(B(x)=A^{-1}(x)\pmod{x^n}\);牛顿迭代 + NTT;\(O(n\log n)\)

  7. polydif(A,B,n):求 \(B(x)=\dfrac{\mathrm{d}A(x)}{\mathrm{d}x}\pmod{x^n}\);微分;\(O(n)\)

  8. polyint(A,B,n):求 \(B(x)=\int A(x)\mathrm{d}x\pmod{x^n}\);积分;\(O(n)\)

  9. polysqrt(A,B,n):求 \(B(x)=A^{\frac{1}{2}}(x)\pmod{x^n}\)(对\(A_0\) 无要求);牛顿迭代 + NTT;\(O(n\log n)\)

  10. polyln(A,B,n):求 \(B(x)=\ln A(x)\pmod{x^n}\)(由麦克劳林级数定义);复合函数微分 + 积分 + NTT;\(O(n\log n)\)

  11. polyexp(A,B,n):求 \(B(x)=\exp A(x)\pmod{x^n}\)(由麦克劳林级数定义);牛顿迭代 + polyln\(O(n\log n)\)

  12. polymod(A,B,C,n,m,D):有 \(A(x)\pmod{x^n}\)\(B(x)\pmod{x^m}\) (\(m<n\)),求 \(D(x)=A(x)\bmod{B(x)}\pmod{x^{m-1}}\)\(C(x)=B^{-1}(x)(A(x)-D(x))\pmod{x^{n-m+1}}\);暴力 + polyinv\(O(n\log n)\)

  13. polyksm(A,B,k,n):求 \(B(x)=A^k(x)\pmod{x^n}\)(要求 \(A_0=1\));初中数学 + polyln + polyexp\(O(n\log n)\)

  14. sppolyksm(A,B,k1,k2,lim,n):求 \(B(x)=A^k(x)\pmod{x^n}\)(对\(A_0\) 无要求),其中 \(k1=k\pmod{p}\) , \(k1=k\pmod{\varphi(p)}\) , \(lim=\min(k,p)\)(用于特判);暴力 + polyksm\(O(n\log n)\)

  15. polycdq(A,B,n,val):求 \(B_0=val,B_i=\sum_{j=0}^{i-1}A_{i-j}B_j\pmod{x^n}\) (最基础的半在线卷积);cdq 分治 + NTT;\(O(n\log^2 n)\)

  16. polyeva(A,B,n,m):给出 \(A(x)\pmod{x^n}\)\(x_1\ldots x_m\),输出 \(A(x_1)\ldots A(x_m)\)。(多点快速求值)


const int N=100005,djq=998244353;
class poly{
    private:
        const int G=3,invG=332748118;
        ll ar1[4*N],ar2[4*N],ar3[4*N],ar4[4*N],ar5[4*N],ar6[4*N];
        ll ivG[4*N],sqG[4*N],lnG[4*N],sqtmp[4*N];
        ll expG[4*N],modG[4*N],kmG[4*N];
        int rev1[4*N],rev2[4*N];
        void POLYINV(ll* A,ll* B,const int n){
            if(n==1) return B[0]=ksm(A[0],djq-2),void();
            POLYINV(A,B,(n+1)>>1);
            int l,lgn; initrev(n<<1,lgn,l,rev2);
            copy(A,A+n,ivG);
            fill(ivG+n,ivG+l,0);
            NTT(B,l,1,rev2),NTT(ivG,l,1,rev2);
            for(rg int i=0;i<l;++i) B[i]=B[i]*(2-B[i]*ivG[i]%djq+djq)%djq;
            NTT(B,l,0,rev2);
            const int invn=ksm(l,djq-2);
            for(rg int j=0;j<l;++j) B[j]=B[j]*invn%djq;
            fill(B+n,B+l,0);
        }
        inline int BSGS(int a,int b){
            map<int,int> mp; b%=djq;
            int t=ceil(sqrt((double)djq)),epow=1;
            for(rg int j=0;j<t;++j,epow=1ll*epow*a%djq) mp[1ll*b*epow%djq]=j;
            a=epow,epow=1;
            for(rg int i=0;i<=t;++i,epow=1ll*epow*a%djq)
                if(mp.find(epow)!=mp.end()&&1ll*i*t-mp[epow]>=0)
                    return 1ll*i*t-mp[epow];
            return b;
        }
        void CDQ(ll* F,ll* G,int l,int r){
            if(l==r) return;
            const int mid=l+r>>1;
            CDQ(F,G,l,mid);
            copy(F+l,F+mid+1,ar4),copy(G,G+r-l+1,ar5);
            fill(ar4+mid-l+1,ar4+(r-l+1)*4,0),fill(ar5+r-l+1,ar5+(r-l+1)*4,0);
            polymul(ar4,ar5,ar4,r-l+1);
            for(rg int i=mid+1;i<=r;++i) F[i]=(F[i]+ar4[i-l])%djq;
            CDQ(F,G,mid+1,r);
        }
        ll *p[4*N]; int len[4*N];
        void up(int x){
			len[x]=len[x<<1]+len[x<<1|1],p[x]=new ll[len[x]+1];
        	int l,lgn; initrev(len[x]+1,lgn,l,rev1);
            fill(ar1,ar1+l,0),fill(ar2,ar2+l,0);
            copy(p[x<<1],p[x<<1]+len[x<<1]+1,ar1),copy(p[x<<1|1],p[x<<1|1]+len[x<<1|1]+1,ar2);
            NTT(ar1,l,1,rev1),NTT(ar2,l,1,rev1);
            for(rg int j=0;j<l;++j) ar1[j]=ar1[j]*ar2[j]%djq;
            NTT(ar1,l,0,rev1);
            const int invn=ksm(l,djq-2);
            for(rg int j=0;j<=len[x];++j) p[x][j]=ar1[j]*invn%djq;
            fill(ar1,ar1+l,0),fill(ar2,ar2+l,0);
        }
        void build(int x,int l,int r,ll* A){
        	if(l==r) return len[x]=1,p[x]=new ll[2],p[x][0]=(djq-A[l]%djq)%djq,p[x][1]=1,void();
        	const int mid=l+r>>1;
        	build(x<<1,l,mid,A);
			build(x<<1|1,mid+1,r,A);
        	up(x);
        }
        void solve(int x,int l,int r,ll* A){
        	if(l==r) return printf("%d\n",(int)A[0]),void();
        	const int mid=l+r>>1;
			ll tmp[len[x]*2+5];
			polymod(A,p[x<<1],tmp,len[x],len[x<<1]+1,ar6);
			solve(x<<1,l,mid,tmp);
			polymod(A,p[x<<1|1],tmp,len[x],len[x<<1|1]+1,ar6);
			solve(x<<1|1,mid+1,r,tmp);
        }
    public:
        inline int ksm(ll base,int p){
            ll ret=1;
            while(p){
                if(p&1) ret*=base,ret%=djq;
                base*=base,base%=djq; p>>=1;
            } return ret;
        }
        inline int modsqrt(int x){
            if(!x||x==1) return x;
            return ksm(3,BSGS(3,x)>>1);
        }
        inline void initrev(const int n,int& lgn,int& l,int* rev){
            lgn=0,l=0;
            while((1<<lgn)<n) ++lgn;  l=(1<<lgn);
            fill(rev,rev+l,0);
            for(rg int j=0;j<l;++j) rev[j]=(rev[j>>1]>>1)|((j&1)<<(lgn-1));
        }
        void NTT(ll* A,const int l,const int opt,int* rev){
            for(rg int i=0;i<l;++i) if(i<rev[i]) swap(A[i],A[rev[i]]);
            for(rg int i=1;i<l;i<<=1){
                const ll gn=ksm(opt?G:invG,(djq-1)/(i<<1));
                for(rg int j=0;j<l;j+=(i<<1)){
                    ll g=1;
                    for(rg int k=0;k<i;++k,g=g*gn%djq){
                        const ll x=A[j+k],y=g*A[i+j+k]%djq;
                        A[j+k]=(x+y)%djq,A[i+j+k]=(x-y+djq)%djq;
                    }
                }
            }
        }
        void polymul(ll* A,ll* B,ll* C,const int n){
            int l,lgn; initrev(n<<1,lgn,l,rev1);
            fill(ar1,ar1+l,0),fill(ar2,ar2+l,0);
            copy(A,A+l,ar1),copy(B,B+l,ar2);
            NTT(ar1,l,1,rev1),NTT(ar2,l,1,rev1);
            for(rg int j=0;j<l;++j) ar1[j]=ar1[j]*ar2[j]%djq;
            NTT(ar1,l,0,rev1);
            const int invn=ksm(l,djq-2);
            for(rg int j=0;j<l;++j) C[j]=ar1[j]*invn%djq;
            fill(ar1,ar1+l,0),fill(ar2,ar2+l,0);
        }
        void polyinv(ll* A,ll* B,const int n){
            fill(B,B+4*n,0),POLYINV(A,B,n);
        }
        void polydif(ll* A,ll* B,const int n){
            for(rg int i=0;i<n-1;++i) B[i]=1ll*(i+1)*A[i+1]%djq;
            B[n-1]=0;
        }
        void polyint(ll* A,ll* B,const int n){
            for(rg int i=n-1;i;--i) B[i]=1ll*A[i-1]*ksm(i,djq-2)%djq;
            B[0]=0;
        }
        void polysqrt(ll* A,ll* B,const int n){
            if(n==1) return B[0]=modsqrt(A[0]),void();
            polysqrt(A,B,(n+1)>>1);
            polyinv(B,sqG,n);
            for(rg int j=0;j<n;++j) sqtmp[j]=A[j];
            fill(sqtmp+n,sqtmp+4*n,0);
            polymul(sqG,sqtmp,sqG,n);
            for(rg int i=0;i<n;++i) B[i]=(B[i]+sqG[i])*ksm(2,djq-2)%djq;
            fill(B+n,B+4*n,0);
        }
        void polyln(ll* A,ll* B,const int n){
            fill(B,B+n*4,0);
            polyinv(A,B,n);
            polydif(A,lnG,n);
            fill(lnG+n-1,lnG+4*n,0);
            polymul(B,lnG,B,n);
            fill(B+n,B+4*n,0);
            polyint(B,B,n);
            fill(B+n,B+4*n,0);
        }
        void polyexp(ll* A,ll* B,const int n){
            if(n==1) return B[0]=1,void();
            polyexp(A,B,(n+1)>>1);
            polyln(B,expG,n);
            for(rg int i=0;i<n;++i) expG[i]=(A[i]-expG[i]+djq)%djq;
            expG[0]=(expG[0]+1)%djq;
            fill(expG+n,expG+4*n,0),fill(B+n,B+4*n,0);
            polymul(B,expG,B,n);
            fill(B+n,B+4*n,0);
        }
        void polymod(ll* A,ll* B,ll* C,const int n,const int m,ll* D){
            copy(A,A+n,D),copy(B,B+m,ar3),copy(B,B+m,ar4);
            reverse(D,D+n),reverse(ar3,ar3+m);
            polyinv(ar3,modG,n-m+1),polymul(D,modG,D,n);
            fill(D+n-m+1,D+4*n,0),reverse(D,D+n-m+1);
            fill(ar3,ar3+m,0),polymul(D,ar4,ar3,n);
            for(rg int i=0;i<m-1;++i) C[i]=(A[i]-ar3[i]+djq)%djq;
            fill(C+m-1,C+2*n,0);
            fill(D,D+4*n,0),fill(ar3,ar3+4*n,0),fill(ar4,ar4+4*n,0);
        }
        void polyksm(ll* A,ll* B,const int k,const int n){
            polyln(A,kmG,n);
            for(rg int i=0;i<n;++i) kmG[i]=kmG[i]*k%djq;
            polyexp(kmG,B,n);
        }
        void sppolyksm(ll* A,ll* B,const int k1,const int k2,const int lim,const int n){
            int tmp=n;
            for(rg int i=0;i<n;++i) if(A[i]){ tmp=i; break; }
            if(1ll*tmp*lim>=1ll*n) return fill(B,B+n,0),void();
            for(rg int i=0;i<n-tmp;++i) ar4[i]=A[i+tmp];
            const int pwA=ksm(ar4[0],k2),invA=ksm(ar4[0],djq-2);
            for(rg int i=0;i<n-tmp;++i) ar4[i]=ar4[i]*invA%djq;
            polyln(ar4,kmG,n-tmp);
            for(rg int i=0;i<n-tmp;++i) kmG[i]=kmG[i]*k1%djq;
            polyexp(kmG,B,n-tmp);
            for(rg int i=0;i<n-tmp;++i) B[i]=B[i]*pwA%djq;
            tmp*=lim;
            for(rg int i=n-1;i>=tmp;--i) B[i]=B[i-tmp];
            fill(B,B+tmp,0);
        }
        void polycdq(ll* A,ll* B,const int n,const int va){
            fill(A+n,A+4*n,0),fill(B,B+4*n,0),B[0]=va;
            CDQ(B,A,0,n-1);
        }
        void polyeva(ll* A,ll* x,const int n,const int m){
        	build(1,1,m,x);
        	if(n>m) polymod(A,p[1],A,n,m+1,ar6);
        	solve(1,1,m,A);
        	for(rg int i=1;i<=4*m;++i) if(p[i]!=NULL) delete p[i],p[i]=NULL;
        }
}work;

posted @ 2022-02-11 17:52  Legitimity  阅读(176)  评论(0编辑  收藏  举报