多项式板子（基于 vector 的版本）

#include<bits/stdc++.h>
// rep: ascending loop, i = a, a+1, ..., b (inclusive). drep: descending loop, i = a, a-1, ..., b.
// The bound b is re-evaluated on every iteration, so it may be an expression such as size()-1.
// `register` is omitted: it was removed in C++17, and Clang rejects it with an error.
#define rep(i,a,b) for (int i=(a);i<=(b);i++)
#define drep(i,a,b) for (int i=(a);i>=(b);i--)
typedef long long ll;
using namespace std;
// Reads the next integer from stdin using getchar().
// Every '-' seen while skipping non-digit characters makes the result negative.
// Returns 0 if EOF is reached before any digit (the old version looped forever there).
inline long long read()
{
	long long sum=0;
	bool neg=false;
	// int, not char: keeps EOF distinguishable and keeps isdigit()'s argument in range.
	int c=getchar();
	while (c!=EOF&&!isdigit(c))
	{
		if (c=='-') neg=true;
		c=getchar();
	}
	if (c==EOF) return 0;
	while (isdigit(c)) sum=sum*10+(c-'0'),c=getchar();
	return neg?-sum:sum;
}

// Arithmetic modulo the NTT-friendly prime 998244353 = 119 * 2^23 + 1.
const int mod=998244353;
// Normalizes x from (-mod, mod) into [0, mod) by adding mod to negative values.
inline int qmo(int x){return x<0?x+mod:x;}
// (x + y) mod p for x, y in [0, mod).
inline int add(int x,int y)
{
	int s=x+y;
	return s>=mod?s-mod:s;
}
// (x - y) mod p for x, y in [0, mod).
inline int sub(int x,int y)
{
	int d=x-y;
	return d<0?d+mod:d;
}
inline void inc(int &x,int y){x=add(x,y);}
inline void dec(int &x,int y){x=sub(x,y);}
// x^k mod p by binary exponentiation (k >= 0).
inline int quick_pow(int x,int k)
{
	int res=1;
	while (k)
	{
		if (k&1) res=1ll*res*x%mod;
		x=1ll*x*x%mod;
		k>>=1;
	}
	return res;
}

// Lazily grown table of modular inverses: __inv[i] = 1/i (mod p). __inv[0] holds a placeholder 1.
vector<int> __inv{1,1};

// Returns the modular inverse of x; intended for 0 < x < mod (getinv(0) yields the placeholder 1).
// Arguments >= 2^21 use Fermat's little theorem instead of growing the table.
inline int getinv(int x)
{
	if (x>=(1<<21)) return quick_pow(x,mod-2);
	// Extend with inv[i] = (mod - mod/i) * inv[mod % i] until index x exists.
	for (size_t i=__inv.size();(int)i<x+1;i++)
		__inv.push_back(1ll*(mod-mod/i)*__inv[mod%i]%mod);
	return __inv[x];
}

// Mersenne Twister seeded from the wall clock, so sequences differ between runs.
std::mt19937 rng(std::time(0));
// Random integer in [l, r] via plain modulo (slightly biased); requires l <= r.
inline int rnd(int l,int r)
{
	int width=r-l+1;
	return l+rng()%width;
}

// w = a^2 - x chosen by Cipolla(); CipollaComplex{x, y} represents x + y*sqrt(w) in F_p[sqrt(w)].
int CipollaVal;
struct CipollaComplex{int x,y;};
// (p + q*sqrt(w)) * (r + s*sqrt(w)) = (pr + qs*w) + (ps + qr)*sqrt(w)
inline CipollaComplex operator * (CipollaComplex lhs,CipollaComplex rhs)
{
	int re=add(1ll*lhs.x*rhs.x%mod,1ll*CipollaVal*lhs.y%mod*rhs.y%mod);
	int im=add(1ll*lhs.x*rhs.y%mod,1ll*lhs.y*rhs.x%mod);
	return CipollaComplex{re,im};
}
// base^k in F_p[sqrt(w)] by binary exponentiation.
inline CipollaComplex quick_pow(CipollaComplex base,int k)
{
	CipollaComplex res{1,0};
	while (k)
	{
		if (k&1) res=res*base;
		base=base*base;
		k>>=1;
	}
	return res;
}

// Returns the smaller of the two square roots of x modulo mod (Cipolla's algorithm).
// x must be 0 or a quadratic residue in [0, mod); for a non-residue the result is meaningless.
inline int Cipolla(int x)
{
	// sqrt(0) = 0. Without this guard every a^2 - 0 is a residue, so the search below
	// only stops when rnd happens to return a multiple of mod (practically never).
	if (x==0) return 0;
	int a;
	// Pick random a until w = a^2 - x is a quadratic non-residue (Euler's criterion != 1).
	do a=rnd(1,1e9);
	while (quick_pow(sub(1ll*a*a%mod,x),(mod-1)/2)==1);
	CipollaVal=sub(1ll*a*a%mod,x);
	// (a + sqrt(w))^((p+1)/2) is a square root of x lying in F_p.
	int v=quick_pow(CipollaComplex{a,1},(mod+1)/2).x;
	return min(v,mod-v);
}

namespace Poly
{
	// Generator of the multiplicative group modulo mod = 998244353.
	const int __G__=3;
	
	// NTT root table in bit-reversed order: rt[1<<i] = G^((mod-1)>>(i+2)) and
	// rt[i] = rt[i&(i-1)] * rt[i&(-i)]. Entries do not depend on lg, so a larger init()
	// only extends the table. A transform of length limit reads rt[0 .. limit/2-1].
	vector<int> rt;
	
	// Builds rt with (1<<lg)+1 entries; lg must be <= 21 because mod-1 = 119 * 2^23.
	inline void init(int lg)
	{
		rt.resize((1<<lg)+1);
		rt[0]=1,rt[1<<lg]=quick_pow(__G__,(mod-1)>>(lg+2));
		drep(i,lg,1) rt[1<<(i-1)]=1ll*rt[1<<i]*rt[1<<i]%mod;
		rep(i,1,1<<lg) rt[i]=1ll*rt[i&(i-1)]*rt[i&(-i)]%mod;
	}
	
	// Grows rt on demand so a transform of length `limit` never reads past its end.
	// Previously rt.data() was dereferenced even if init() had never been called.
	inline void ensure_roots(int limit)
	{
		int need=limit>>1;
		if ((int)rt.size()>=need) return;
		int lg=0;
		while ((1<<lg)+1<need) lg++;
		assert(lg<=21); // transforms longer than 2^22 exceed what this root table supports
		init(lg);
	}
	
	// Forward NTT (decimation in frequency, no bit-reversal permutation).
	// a.size() must be a power of two; on return a holds evaluations in bit-reversed order.
	inline void dif(vector<int> &a)
	{
		int limit=a.size();
		ensure_roots(limit);
		for (int len=limit>>1;len;len>>=1)
		{
			for (int j=0,*w=rt.data();j<limit;j+=(len<<1),w++)
			{
				for (int k=j,r;k<j+len;k++)
				{
					r=1ll*(*w)*a[k+len]%mod;
					a[k+len]=sub(a[k],r);
					inc(a[k],r);
				}
			}
		}
	}
	
	// Inverse of dif(): runs the butterflies with the same roots, then reverses a[1..]
	// (which turns w into w^-1) and scales by 1/limit.
	inline void dit(vector<int> &a)
	{
		int limit=a.size();
		if (!limit) return; // a.begin()+1 below would be past the end
		ensure_roots(limit);
		for (int len=1;len<limit;len<<=1)
		{
			for (int j=0,*w=rt.data();j<limit;j+=(len<<1),w++)
			{
				for (int k=j,r;k<j+len;k++)
				{
					r=add(a[k],a[k+len]);
					a[k+len]=1ll*sub(a[k],a[k+len])*(*w)%mod;
					a[k]=r;
				}
			}
		}
		reverse(a.begin()+1,a.end());
		// Loop-invariant: for limit >= 2^21 getinv() costs a full quick_pow per call.
		const int inv_limit=getinv(limit);
		rep(i,0,limit-1) a[i]=1ll*a[i]*inv_limit%mod;
	} 
	
	// Polynomial over Z/modZ; a[i] is the coefficient of x^i.
	struct Poly
	{
		vector<int> a;
		Poly () {}
		Poly (const vector<int> &x):a(x) {}
		Poly (const initializer_list<int> &x):a(x) {}
		inline int size() const {return a.size();}
		inline void resize(int n) {a.resize(n);}
		// Bounds-checked read: coefficients outside [0, size()) are 0.
		inline int operator [] (int n) const
		{
			if (n<0||n>=size()) return 0;
			return a[n];
		}
		// Coefficients in reverse order (x^(size-1) * P(1/x)).
		inline Poly reverse() const {return Poly(vector<int>(a.rbegin(),a.rend()));}
		// Multiply by x^n (prepend n zeros).
		inline Poly mulxn(int n) const {auto b=a; b.insert(b.begin(),n,0); return Poly(b);}
		// Divide by x^n, dropping the n lowest coefficients.
		inline Poly divxn(int n) const {if (n>=size()) return Poly(); return Poly(vector<int>(a.begin()+n,a.end()));}
		// Truncate modulo x^n; n <= 0 gives the empty polynomial.
		inline Poly modxn(int n) const {if (!size()) return Poly(); int k=max(0,min(size(),n)); return Poly(vector<int>(a.begin(),a.begin()+k));}
		// Remove trailing zero coefficients (the zero polynomial becomes empty).
		inline Poly shrink() const {if (!size()) return Poly(); int lst=size()-1; while (lst>=0&&!a[lst]) lst--; return Poly(vector<int>(a.begin(),a.begin()+lst+1));}
		inline friend Poly operator + (const Poly &a,const Poly &b)
		{
			vector<int> res(max(a.size(),b.size()));
			rep(i,0,(int)res.size()-1) res[i]=add(a[i],b[i]);
			return Poly(res);
		}
		inline friend Poly operator - (const Poly &a,const Poly &b)
		{
			vector<int> res(max(a.size(),b.size()));
			rep(i,0,(int)res.size()-1) res[i]=sub(a[i],b[i]);
			return Poly(res);
		}
		// Full product, with trailing zeros removed. Schoolbook when either factor has at
		// most 40 coefficients, NTT otherwise.
		inline friend Poly operator * (Poly a,Poly b)
		{
			if (!a.size()||!b.size()) return Poly();
			if (a.size()<=40||b.size()<=40)
			{
				if (a.size()>b.size()) swap(a,b);
				vector<int> res(a.size()+b.size()-1);
				rep(i,0,(int)(res.size()-1))
				{
					for (int j=max(0,i-b.size()+1);j<=i&&j<a.size();j++) inc(res[i],1ll*a[j]*b[i-j]%mod);
				}
				return Poly(res).shrink();
			}
			int limit=1,sz=a.size()+b.size()-1;
			while (limit<sz) limit<<=1; a.a.resize(limit),b.a.resize(limit);
			dif(a.a),dif(b.a); 
			rep(i,0,limit-1) a.a[i]=1ll*a.a[i]*b.a[i]%mod;
			dit(a.a);
			return a.shrink(); 
		}
		inline friend Poly operator * (Poly a,int b) {rep(i,0,a.size()-1) a.a[i]=1ll*a.a[i]*b%mod; return a;}
		inline friend Poly operator * (int a,Poly b) {rep(i,0,b.size()-1) b.a[i]=1ll*b.a[i]*a%mod; return b;}
		inline Poly& operator += (Poly b) {return (*this)=(*this)+b;}
		inline Poly& operator -= (Poly b) {return (*this)=(*this)-b;}
		inline Poly& operator *= (Poly b) {return (*this)=(*this)*b;}
		inline Poly& operator *= (int b) {return (*this)=(*this)*b;}
		// Equality ignoring trailing zeros.
		inline friend bool operator == (const Poly &a,const Poly &b) {rep(i,0,max(a.size(),b.size())-1) if (a[i]!=b[i]) return false; return true;}
		inline Poly deriv() const
		{
			if (!size()) return Poly();
			vector<int> res(size()-1);
			rep(i,0,size()-2) res[i]=1ll*a[i+1]*(i+1)%mod;
			return Poly(res);
		}
		// Antiderivative with zero constant term.
		inline Poly integ() const
		{
			vector<int> res(size()+1);
			rep(i,0,size()-1) res[i+1]=1ll*a[i]*getinv(i+1)%mod;
			return Poly(res);
		}
		// Inverse modulo x^n by Newton iteration res <- res*(2 - f*res).
		// Requires a[0] != 0. The empty polynomial (which has no inverse) yields Poly().
		inline Poly inv(int n) const
		{
			if (!size()) return Poly();
			Poly res{getinv(a[0])},tmp;
			int k=1;
			while (k<n)
			{
				// deg(res*f*res) < k/2 + k + k/2 - 1 < 2k, so a 2k-point cyclic NTT is exact.
				k<<=1; int limit=k<<1; tmp.resize(limit); res.resize(limit);
				rep(i,0,k-1) tmp.a[i]=(*this)[i];
				dif(tmp.a),dif(res.a);
				rep(i,0,limit-1) res.a[i]=1ll*res[i]*sub(2,1ll*tmp[i]*res[i]%mod)%mod;
				dit(res.a);
				rep(i,k,limit-1) res.a[i]=0;
				rep(i,0,limit-1) tmp.a[i]=0;
			}
			return res.modxn(n);
		}
		// Square root modulo x^n via x <- (x + f/x)/2.
		// Requires a[0] to be a nonzero quadratic residue.
		inline Poly sqrt(int n) const
		{
			Poly x{Cipolla(a[0])};
			int k=1;
			while (k<n)
			{
				k<<=1;
				x=(x+(modxn(k)*x.inv(k))).modxn(k)*getinv(2);
			}
			return x.modxn(n);
		}
		// ln(f) modulo x^n as the integral of f'/f; requires a[0] == 1.
		inline Poly ln(int n) const {return (modxn(n).deriv()*inv(n)).modxn(n).integ().modxn(n);}
		// exp(f) modulo x^n via Newton iteration res <- res*(1 - ln res + f); requires a[0] == 0.
		inline Poly exp(int n) const
		{
			Poly res{1};
			int k=1;
			while (k<n)
			{
				k<<=1;
				res=(res*(Poly{1}-res.ln(k)+modxn(k))).modxn(k);
			}
			return res.modxn(n);
		}
		// f^k modulo x^n: strips the lowest nonzero term c*x^i, computes exp(k*ln(f/(c x^i))),
		// then multiplies by c^k x^(ik). Assumes k >= 0.
		inline Poly pow(int k,int n) const
		{
			int i=0; while (i<size()&&!a[i]) i++;
			if (i==size()||1ll*i*k>=n) return Poly();
			Poly x=quick_pow(a[i],mod-2)*divxn(i);
			return (x.ln(n-i*k)*k).exp(n-i*k).mulxn(i*k)*quick_pow(a[i],k); 
		}
		// Euclidean division: returns (q, r) with *this = q*o + r and size(r) < size(o).
		// o must have a nonzero leading coefficient (no trailing zeros).
		inline pair<Poly,Poly> div(const Poly &o) const
		{
			if (size()<o.size()) return make_pair(Poly(),*this);
			int m=size()-o.size()+1; // number of quotient coefficients
			// rev(q) = rev(*this) * rev(o)^-1 mod x^m
			Poly f=(reverse().modxn(m)*o.reverse().modxn(m).inv(m)).modxn(m);
			// operator* strips trailing zeros; pad back to exactly m coefficients before
			// reversing, otherwise a quotient with zero low-order terms (e.g. x^3 / x)
			// comes out shifted.
			f.resize(m);
			f=f.reverse();
			Poly g=(modxn(o.size()-1)-o.modxn(o.size()-1)*f.modxn(o.size()-1)).modxn(o.size()-1);
			return make_pair(f,g);
		}
	};
}
using Poly::Poly;

下面是数组版的（需要自行定义数组大小常量 N（须大于用到的最大变换长度，约为 4n），并沿用上文的 rep/drep 宏等辅助定义）

// Arithmetic modulo the NTT-friendly prime 998244353 = 119 * 2^23 + 1.
const int mod=998244353;
// Normalizes x from (-mod, mod) into [0, mod) by adding mod to negative values.
inline int qmo(int x){return x<0?x+mod:x;}
// (x + y) mod p for x, y in [0, mod).
inline int add(int x,int y)
{
	int s=x+y;
	return s>=mod?s-mod:s;
}
// (x - y) mod p for x, y in [0, mod).
inline int sub(int x,int y)
{
	int d=x-y;
	return d<0?d+mod:d;
}
inline void inc(int &x,int y){x=add(x,y);}
inline void dec(int &x,int y){x=sub(x,y);}
// x^k mod p by binary exponentiation (k >= 0).
inline int quick_pow(int x,int k)
{
	int res=1;
	while (k)
	{
		if (k&1) res=1ll*res*x%mod;
		x=1ll*x*x%mod;
		k>>=1;
	}
	return res;
}

namespace Poly
{
	// A square root of -1 modulo mod (img^2 == mod-1): the imaginary unit for sin/cos/tan via exp(i*A).
	const int img=86583718;
	// limit: transform length chosen by the last PolyInit(); rev: its bit-reversal permutation;
	// G/Ginv: primitive root 3 and its inverse. N is NOT defined here: the user must supply it,
	// larger than every transform length used (PolyInit(n<<1) picks a power of two in (2n, 4n]).
	int limit,G,Ginv,rev[N];
	// Sets limit to the smallest power of two greater than `size` and rebuilds rev[].
	inline void PolyInit(int size)
	{
		limit=1; int l=0;
		while (limit<=size) limit<<=1,l++;
		rep(i,0,limit-1) rev[i]=((rev[i>>1]>>1)|(i&1)<<(l-1));
		G=3,Ginv=quick_pow(G,mod-2);
	}
	// In-place iterative NTT on A[0..limit-1]: tp == 1 forward, tp == -1 inverse (includes 1/limit).
	inline void NTT(int *A,int tp)
	{
		rep(i,0,limit-1) if (i<rev[i]) swap(A[i],A[rev[i]]);
		for (int mid=1;mid<limit;mid<<=1)
		{
			int Wn=quick_pow(tp==1?G:Ginv,mod/(mid<<1));
			for (int j=0;j<limit;j+=(mid<<1))
			{
				int w=1;
				rep(k,0,mid-1)
				{
					int x=A[j+k],y=1ll*A[j+k+mid]*w%mod;
					A[j+k]=add(x,y),A[j+k+mid]=sub(x,y);
					w=1ll*w*Wn%mod;
				}
			}
		}
		if (tp==-1)
		{
			int I=quick_pow(limit,mod-2);
			rep(i,0,limit-1) A[i]=1ll*A[i]*I%mod;
		}
	}
	// target[0..n-1] = (A * B) mod x^n. Only A[0..n-1] and B[0..n-1] are read.
	inline void PolyMul(int *target,int *A,int *B,int n)
	{
		static int tmp[N],tmp1[N],tmp2[N];
		PolyInit(n<<1);
		// Zero-pad beyond n: whatever the caller's arrays hold at index >= n is not part of
		// the input and would otherwise wrap around into the low coefficients.
		rep(i,0,limit-1)
		{
			tmp1[i]=i<n?A[i]:0;
			tmp2[i]=i<n?B[i]:0;
		}
		NTT(tmp1,1),NTT(tmp2,1);
		rep(i,0,limit-1) tmp[i]=1ll*tmp1[i]*tmp2[i]%mod;
		NTT(tmp,-1);
		rep(i,0,n-1) target[i]=tmp[i];
	}
	// B[0..n-1] = A^-1 mod x^n (Newton iteration); requires A[0] != 0.
	// B must be zero-filled on entry up to the transform length (< 4n); on return B[n..] is zero.
	void PolyInv(int *A,int *B,int n)
	{
		static int tmp[N];
		if (n==1) return B[0]=quick_pow(A[0],mod-2),void();
		int mid=(n+1)>>1; PolyInv(A,B,mid); PolyInit(n<<1);
		rep(i,0,n-1) tmp[i]=A[i];
		NTT(B,1),NTT(tmp,1);
		rep(i,0,limit-1) B[i]=1ll*B[i]*sub(2,1ll*tmp[i]*B[i]%mod)%mod;
		NTT(B,-1);
		rep(i,n,limit-1) B[i]=0;
		rep(i,0,limit-1) tmp[i]=0;
	}
	// Factorials, inverse factorials and inv[i] = 1/i for i in [1, size].
	int fac[N],_fac[N],inv[N];
	inline void init(int size)
	{
		fac[0]=_fac[0]=inv[0]=1;
		rep(i,1,size) fac[i]=1ll*fac[i-1]*i%mod;
		_fac[size]=quick_pow(fac[size],mod-2);
		drep(i,size-1,1) _fac[i]=1ll*_fac[i+1]*(i+1)%mod;
		rep(i,1,size) inv[i]=1ll*fac[i-1]*_fac[i]%mod;
	}
	// B[0..n-1] = A'(x), keeping n coefficients (top one is 0); n >= 1. A and B may alias.
	inline void PolyDer(int *A,int *B,int n)
	{
		static int tmp[N];
		rep(i,0,n-2) tmp[i]=1ll*A[i+1]*(i+1)%mod; tmp[n-1]=0;
		rep(i,0,n-1) B[i]=tmp[i];
	}
	// B[0..n-1] = integral of A with zero constant term (uses inv[], filled by init()).
	inline void PolyInt(int *A,int *B,int n)
	{
		static int tmp[N];
		rep(i,1,n-1) tmp[i]=1ll*A[i-1]*inv[i]%mod; tmp[0]=0;
		rep(i,0,n-1) B[i]=tmp[i];
	}
	// B[0..n-1] = ln(A) mod x^n as the integral of A'/A; assumes A[0] == 1.
	inline void PolyLn(int *A,int *B,int n)
	{
		static int tmp1[N],tmp2[N];
		init(n); PolyInv(A,tmp2,n); PolyDer(A,tmp1,n);
		PolyInit(n<<1); NTT(tmp1,1),NTT(tmp2,1);
		rep(i,0,limit-1) tmp1[i]=1ll*tmp1[i]*tmp2[i]%mod;
		NTT(tmp1,-1); PolyInt(tmp1,B,n);
		rep(i,0,limit-1) tmp1[i]=tmp2[i]=0;
	}
	// B[0..n-1] = exp(A) mod x^n via Newton iteration; assumes A[0] == 0.
	// B must be zero-filled on entry up to the transform length (< 4n); on return B[n..] is zero.
	inline void PolyExp(int *A,int *B,int n)
	{
		static int tmp1[N],tmp2[N];
		if (n==1) return B[0]=1,void();
		int mid=(n+1)>>1; PolyExp(A,B,mid);
		rep(i,0,n-1) tmp2[i]=A[i];
		PolyLn(B,tmp1,n); PolyInit(n<<1); NTT(B,1),NTT(tmp1,1),NTT(tmp2,1);
		rep(i,0,limit-1) B[i]=1ll*B[i]*sub(1,sub(tmp1[i],tmp2[i]))%mod;
		NTT(B,-1);
		rep(i,n,limit-1) B[i]=0;
		rep(i,0,limit-1) tmp1[i]=tmp2[i]=0;
	}
	// B[0..n-1] = A^k mod x^n as exp(k * ln A); assumes A[0] == 1 and B zero-filled as for PolyExp.
	inline void PolyPow(int *A,int *B,int n,int k)
	{
		static int tmp[N];
		PolyLn(A,tmp,n);
		rep(i,0,n-1) tmp[i]=1ll*tmp[i]*k%mod;
		PolyExp(tmp,B,n);
		PolyInit(n<<1);
		rep(i,0,limit-1) tmp[i]=0;
	}
	// B[0..n-1] = sin(A) = (e^{iA} - e^{-iA}) / (2i) mod x^n; assumes A[0] == 0.
	inline void PolySin(int *A,int *B,int n)
	{
		static int tmp1[N],tmp2[N],tmp3[N],tmp4[N];
		memset(tmp1,0,sizeof(int)*n);
		memset(tmp2,0,sizeof(int)*n);
		memset(tmp3,0,sizeof(int)*n);
		memset(tmp4,0,sizeof(int)*n);
		rep(i,0,n-1) tmp1[i]=1ll*A[i]*img%mod,tmp2[i]=1ll*A[i]*sub(0,img)%mod;
		PolyExp(tmp1,tmp3,n),PolyExp(tmp2,tmp4,n);
		const int _inv=quick_pow(add(img,img),mod-2);
		rep(i,0,n-1) B[i]=1ll*sub(tmp3[i],tmp4[i])*_inv%mod;
		// PolyExp leaves tmp3/tmp4 zero beyond n, so clearing the first n entries makes all
		// scratch buffers all-zero again. Otherwise a later call with smaller n would feed
		// stale coefficients into PolyExp's NTTs.
		memset(tmp1,0,sizeof(int)*n);
		memset(tmp2,0,sizeof(int)*n);
		memset(tmp3,0,sizeof(int)*n);
		memset(tmp4,0,sizeof(int)*n);
	}
	// B[0..n-1] = cos(A) = (e^{iA} + e^{-iA}) / 2 mod x^n; assumes A[0] == 0.
	inline void PolyCos(int *A,int *B,int n)
	{
		static int tmp1[N],tmp2[N],tmp3[N],tmp4[N];
		memset(tmp1,0,sizeof(int)*n);
		memset(tmp2,0,sizeof(int)*n);
		memset(tmp3,0,sizeof(int)*n);
		memset(tmp4,0,sizeof(int)*n);
		rep(i,0,n-1) tmp1[i]=1ll*A[i]*img%mod,tmp2[i]=1ll*A[i]*sub(0,img)%mod;
		PolyExp(tmp1,tmp3,n),PolyExp(tmp2,tmp4,n);
		const int _inv=quick_pow(2,mod-2); 
		rep(i,0,n-1) B[i]=1ll*add(tmp3[i],tmp4[i])*_inv%mod;
		// Restore all-zero scratch so the next (possibly smaller-n) call sees no stale data.
		memset(tmp1,0,sizeof(int)*n);
		memset(tmp2,0,sizeof(int)*n);
		memset(tmp3,0,sizeof(int)*n);
		memset(tmp4,0,sizeof(int)*n);
	}
	// B[0..n-1] = tan(A) = sin(A) / cos(A) mod x^n; assumes A[0] == 0.
	inline void PolyTan(int *A,int *B,int n)
	{
		static int tmp1[N],tmp2[N],tmp3[N],tmp4[N];
		memset(tmp1,0,sizeof(int)*n);
		memset(tmp2,0,sizeof(int)*n);
		memset(tmp3,0,sizeof(int)*n);
		memset(tmp4,0,sizeof(int)*n);
		rep(i,0,n-1) tmp1[i]=1ll*A[i]*img%mod,tmp2[i]=1ll*A[i]*sub(0,img)%mod;
		PolyExp(tmp1,tmp3,n),PolyExp(tmp2,tmp4,n);
		const int _inv1=quick_pow(add(img,img),mod-2),_inv2=quick_pow(2,mod-2);
		rep(i,0,n-1) tmp1[i]=1ll*sub(tmp3[i],tmp4[i])*_inv1%mod;
		rep(i,0,n-1) tmp2[i]=1ll*add(tmp3[i],tmp4[i])*_inv2%mod;
		memset(tmp3,0,sizeof(int)*n);
		PolyInv(tmp2,tmp3,n); PolyMul(B,tmp1,tmp3,n); 
		// PolyInv/PolyExp leave tmp3/tmp4 zero beyond n; clear the first n entries so every
		// scratch buffer is all-zero for the next call.
		memset(tmp1,0,sizeof(int)*n);
		memset(tmp2,0,sizeof(int)*n);
		memset(tmp3,0,sizeof(int)*n);
		memset(tmp4,0,sizeof(int)*n);
	} 
}
using namespace Poly;
posted @ 2021-04-07 13:28  ZSH_ZSH  阅读(90)  评论(1)  编辑  收藏  举报