多项式合集

多项式乘法

FFT

这里

NTT

可以求出两个多项式相乘结果系数对任意NTT模数(可以表示为\(a\times2^b+1\)形式的质数)取模的结果。

其实只要把FFT里的单位副根变为该模数的原根就好了。

常见的NTT模数为998244353,原根为3。

多项式求逆

这里

多项式板子

包括了NTT,求逆,ln,exp,快速幂。

看起来并不是非常高效

namespace poly{
	#define vi vector<int>
	#define ci const int&
	#define Red(x) (x+=(x>>31)&mod)
	const int LM=(1<<22),mod=998244353;
	int lm,lg[LM+10],rev[LM+10],rt[LM+10][2],iv[LM+10],*p,*q;
	int POW(int x,int y){
		int ret=1;
		while(y)y&1?ret=1ll*ret*x%mod:0,x=1ll*x*x%mod,y>>=1;
		return ret;
	}
	void NTT(vi&f,ci op){
		int tn=f.size(),l=lg[tn],r,t1,t2;
		long long nr;
		for(int i=0;i<tn;++i)rev[i]=(rev[i>>1]>>1)+(i&1)*(1<<l-1),rev[i]<i?swap(f[rev[i]],f[i]),0:0;
		for(int i=2;i<=tn;i<<=1){
			r=rt[i][op];
			for(int j=0;j<tn;j+=i){
				nr=1,p=&f[j],q=&f[j+(i>>1)];
				for(int k=j;k<j+(i>>1);++k,nr=nr*r%mod,++p,++q)t1=*p,t2=nr*(*q)%mod,(*p)-=(((*p)=t1+t2)>=mod?mod:0),(*q)=t1-t2,Red((*q));
			}
		}
		if(op)for(int i=0;i<tn;++i)f[i]=1ll*f[i]*iv[tn]%mod;
	}
	vi Poly(ci x){
		vi ret;
		return ret.push_back(x),ret;
	}
	vi Plus(vi x,vi y){
		int sz=max(x.size(),y.size());
		x.resize(sz),y.resize(sz);
		for(int i=0;i<sz;++i)(x[i]+=y[i])>=mod?x[i]-=mod:0;
		return x;
	}
	vi Minus(vi x,vi y){
		int sz=max(x.size(),y.size());
		x.resize(sz),y.resize(sz);
		for(int i=0;i<sz;++i)x[i]-=y[i],Red(x[i]);
		return x;
	}
	vi Mul(vi x,ci y){
		for(int i=0;i<x.size();++i)x[i]=1ll*x[i]*y%mod;
		return x;
	}
	vi Mul(vi x,vi y,ci sz){
		int tl=x.size()+y.size()-1,lth=1;
		while(lth<tl)lth<<=1;
		x.resize(lth),y.resize(lth),NTT(x,0),NTT(y,0);
		for(int i=0;i<lth;++i)x[i]=1ll*x[i]*y[i]%mod;
		NTT(x,1),x.resize(sz);
		return x;
	}
	vi Inv(vi x){
		if(x.size()==1)return x[0]=POW(x[0],mod-2),x;
		vi tmp=x;
		int ts=x.size(),sz=(ts+1>>1);
		tmp.resize(sz),tmp=Inv(tmp);
		int tl=ts+tmp.size()+tmp.size()-2,lth=1;
		while(lth<tl)lth<<=1;
		x.resize(lth),tmp.resize(lth),NTT(x,0),NTT(tmp,0);
		for(int i=0;i<lth;++i)tmp[i]=(2-1ll*x[i]*tmp[i])%mod*tmp[i]%mod,Red(tmp[i]);
		NTT(tmp,1);
		return tmp.resize(ts),tmp;
	}
	vi Ln(vi x){
		vi tmp=x;
		for(int i=1;i<tmp.size();++i)tmp[i-1]=1ll*i*tmp[i]%mod;
		tmp.pop_back(),tmp=Mul(tmp,Inv(x),x.size());
		for(int i=x.size()-1;i>0;--i)tmp[i]=1ll*tmp[i-1]*iv[i]%mod;
		tmp[0]=0;
		return tmp;
	}
	vi Exp(vi x){
		if(x.size()==1)return x[0]=1,x;
		int sz=(x.size()+1>>1);
		vi tmp=x,t2;
		tmp.resize(sz),tmp=Exp(tmp),t2=tmp;
		t2.resize(x.size());
		return Mul(tmp,Plus(Minus(Poly(1),Ln(t2)),x),x.size());
	}
	vi POW(vi x,ci y,ci yc){
		int pw=0,in,ls,ns=x.size();
		while(pw<ns&&!x[pw])++pw;
		if(pw==ns)return x;
		if(1ll*pw*y>=ns){
			vi ret;
			return ret.resize(ns),ret;
		}
		for(int i=pw;i<x.size();++i)x[i-pw]=x[i];
		x.resize(ns-pw);
		in=POW(ls=x[0],mod-2),x=Mul(x,in);
		vi tmp=Exp(Mul(Ln(x),y));
		ls=POW(ls,yc),tmp=Mul(tmp,ls);
		vi ret;ret.resize(ns);
		for(int i=0;i+pw*y<ns;++i)ret[i+pw*y]=tmp[i];
		return ret;
	}
	/*vi val[LM<<1],vl;
	int cnt;
	void Solve(ci l,ci r){
		++cnt,val[cnt].resize(0);
		if(l>r)return(void)(val[cnt].push_back(1));
		if(l==r)return(void)(val[cnt].push_back(mod-l),val[cnt].push_back(1));
		int id=cnt,mid=l+r>>1,lc=cnt+1,rc;
		Solve(l,mid),rc=cnt+1,Solve(mid+1,r);
		val[id]=Mul(val[lc],val[rc],val[lc].size()+val[rc].size()-1);
	}
	vi Calc(vi x){
		vl=x;
		cnt=0,Solve(0,x.size()-1);
		
	}*/
	void init(ci x){
		lm=1;
		while(lm<x)lm<<=1;
		for(int i=2;i<=lm;++i)lg[i]=lg[i>>1]+1,lg[i]!=lg[i-1]?rt[i][0]=POW(3,(mod-1)/i),rt[i][1]=POW(rt[i][0],mod-2):0;
		for(int i=1;i<=lm;++i)iv[i]=POW(i,mod-2);
	}
}
posted @ 2019-03-25 11:43  xryjr233  阅读(414)  评论(3编辑  收藏  举报