多项式模板

这里给出了一个多项式/形式幂级数的类的实现。由于仅考虑形式幂级数的前n项,即$\!\bmod{x^n}$下的等价类,故其形式与多项式相同,因而在类的实现上没有对两者作区分。没有返回值的方法会将结果用于更新自身。对于所有用于形式幂级数的方法,参数n代表$\!\bmod{x^n}$,且保证调用后前n项均可访问。下表说明了各方法的复杂度。

  • der(求导)/pri(积分):$O(n)$
  • mul/inv/log/exp:$O(n\log n)$
  • div/mod:$O(n\log n)$
  • 在$m$个点处求值:$O(n\log n+m\log^2\min(m,n))$
  • 以根集构造:$O(n\log^2n)$
  • 以点集构造(插值):$O(n\log^2n)$

注意,对于某些算式,相比直接调用以上方法进行计算,使用针对性的实现能使效率大幅提高。


#include<algorithm>
#include<vector>
#define RAN(a)a.begin(),a.end()
using namespace std;
typedef unsigned long long u64;
typedef unsigned u32;
namespace num{
	const u32 p=998244353;
	const u32 g=3;
	inline u32 imod(u32 a){
		return a<p?a:a-p;
	}
	inline u32 ipow(u32 a,u32 n){
		u32 s=1;
		for(;n;n>>=1){
			if(n&1)
				s=(u64)s*a%p;
			a=(u64)a*a%p;
		}
		return s;
	}
	class inv_t{
	public:
		inv_t():f(1,1){}
		u32 operator[](int n){
			int m=f.size();
			if(m<n){
				f.resize(n);
				for(int i=m+1;i<=n;++i)
					f[i-1]=(u64)(p-p/i)*f[p%i-1]%p;
			}
			return f[n-1];
		}
		u32 operator()(u32 a)const{
			return ipow(a,p-2);
		}
	private:
		vector<u32>f;
	}inv;
}
using namespace num;
class poly{
public:
	vector<u32>a;
	poly(){}
	explicit poly(int n):a(n){}
	poly(const vector<u32>&b,int n):
		a(b.begin(),b.begin()+n)
	{}
	poly(const poly&b,int n):
		a(b.a.begin(),b.a.begin()+n)
	{}
	u32&operator[](int i){
		return a[i];
	}
	const u32&operator[](int i)const{
		return a[i];
	}
	int size()const{
		return a.size();
	}
	void swap(poly&b){
		a.swap(b.a);
	}
	void shl(int k=1){
		a.insert(a.begin(),k,0);
	}
	void shr(int k=1){
		a.erase(a.begin(),a.begin()+k);
	}
	void der(){
		fix();
		if(size()){
			for(int i=1;i<size();++i)
				a[i-1]=(u64)i*a[i]%p;
			a.pop_back();
		}
	}
	void pri(){
		fix();
		shl();
		for(int i=size()-1;i>0;--i)
			a[i]=(u64)a[i]*num::inv[i]%p;
	}
	static int len(int n){
		while(n^n&-n)
			n+=n&-n;
		return n;
	}
	void fft(int n,bool f){
		a.resize(n);
		if(n<=1)
			return;
		for(int i=0,j=0;i<n;++i){
			if(i<j)
				std::swap(a[i],a[j]);
			int k=n>>1;
			while((j^=k)<k)
				k>>=1;
		}
		vector<u32>w(n/2);
		w[0]=1;
		for(int i=1;i<n;i<<=1){
			for(int j=i/2-1;~j;--j)
				w[j<<1]=w[j];
			int m=(p-1)/i/2;
			u64 s=ipow(g,f?p-1-m:m);
			for(int j=1;j<i;j+=2)
				w[j]=s*w[j-1]%p;
			for(int j=0;j<n;j+=i<<1){
				u32*b=&a[0]+j,*c=b+i;
				for(int k=0;k<i;++k){
					u32 v=(u64)w[k]*c[k]%p;
					c[k]=imod(b[k]+p-v);
					b[k]=imod(b[k]+v);
				}
			}
		}
	}
	void dft(int n){
		fft(n,0);
	}
	void idft(){
		int n=size();
		fft(n,1);
		u64 f=num::inv(n);
		for(int i=0;i<n;++i)
			a[i]=f*a[i]%p;
	}
	void fix(){
		while(size()&&!a.back())
			a.pop_back();
	}
	void mul_ref(poly&b){
		fix();
		if(&b!=this)
			b.fix();
		int n=len(size()+b.size()-1);
		dft(n);
		if(&b!=this)
			b.dft(n);
		for(int i=0;i<n;++i)
			a[i]=(u64)a[i]*b[i]%p;
		idft();
		fix();
	}
	void mul(poly b){
		mul_ref(b);
	}
	void sqr(){
		mul_ref(*this);
	}
	void imul(u32 k){
		for(int i=0;i<size();++i)
			a[i]=(u64)k*a[i]%p;
	}
	void mod(int n){
		a.resize(n);
	}
	void zero(int n){
		a.clear();
		mod(n);
	}
	void mul(poly b,int n){
		b.mod(n);
		mul_ref(b);
		mod(n);
	}
	void inv(int n){
		int m=len(n);
		mod(m);
		poly f(m);
		swap(f);
		a[0]=num::inv(f[0]);
		for(int i=1;i<m;i<<=1){
			poly s(f,i*2);
			poly t(a,i*2);
			s.dft(i*2);
			t.dft(i*2);
			for(int j=0;j<i*2;++j)
				s[j]=(u64)s[j]*t[j]%p;
			s.idft();
			for(int j=0;j<i;++j)
				s[j]=0;
			s.dft(i*2);
			for(int j=0;j<i*2;++j)
				s[j]=(u64)s[j]*(p-t[j])%p;
			s.idft();
			for(int j=i;j<i*2;++j)
				a[j]=s[j];
		}
		mod(n);
	}
	void inv_mul(poly b,int n){
		if(n==1){
			inv(n);
			return mul(b,n);
		}
		int m=len(n)/2;
		mod(m*2);
		b.mod(m*2);
		poly c(a,m);
		c.inv(m);
		c.dft(m*2);
		poly d(b,m);
		d.dft(m*2);
		for(int j=0;j<m*2;++j)
			d[j]=(u64)d[j]*c[j]%p;
		d.idft();
		poly t(d,m);
		t.dft(m*2);
		dft(m*2);
		for(int j=0;j<m*2;++j)
			a[j]=(u64)a[j]*t[j]%p;
		idft();
		for(int j=0;j<m;++j)
			a[j]=0;
		for(int j=m;j<m*2;++j)
			a[j]=imod(a[j]+p-b[j]);
		dft(m*2);
		for(int j=0;j<m*2;++j)
			a[j]=(u64)a[j]*(p-c[j])%p;
		idft();
		for(int j=0;j<m;++j)
			a[j]=d[j];
		mod(n);
	}
	void log(int n){
		if(n==1)
			return zero(n);
		mod(n);
		poly b=*this;
		b.der();
		inv_mul(b,n-1);
		pri();
		mod(n);
	}
	void exp(int n,poly*r=0){
		int m=len(n);
		mod(m);
		poly f(m);
		swap(f);
		a[0]=1;
		poly b(m);
		b[0]=1;
		poly c(1);
		c[0]=1;
		for(int i=1;i<m;i<<=1){
			poly s(f,i);
			s.der();
			s.dft(i);
			for(int j=0;j<i;++j)
				s[j]=(u64)s[j]*c[j]%p;
			s.idft();
			poly t(a,i);
			t.der();
			t.mod(i-1);
			s.mod(i*2);
			for(int j=0;j<i-1;++j){
				s[j+i]=imod(t[j]+p-s[j]);
				s[j]=0;
			}
			s[i-1]=imod(p-s[i-1]);
			s.dft(i*2);
			t=poly(b,i);
			t.dft(i*2);
			for(int j=0;j<i*2;++j)
				s[j]=(u64)s[j]*t[j]%p;
			s.idft();
			for(int j=0;j<i-1;++j)
				s[j]=0;
			s.pri();
			s.mod(i*2);
			for(int j=i;j<i*2;++j)
				s[j]=imod(s[j]+p-f[j]);
			s.dft(i*2);
			c=poly(a,i);
			c.dft(i*2);
			for(int j=0;j<i*2;++j)
				s[j]=(u64)s[j]*c[j]%p;
			s.idft();
			for(int j=i;j<i*2;++j)
				a[j]=imod(p-s[j]);
			if(!r&&i==m/2)
				break;
			c=poly(a,i*2);
			c.dft(i*2);
			for(int j=0;j<i*2;++j)
				s[j]=(u64)c[j]*t[j]%p;
			s.idft();
			for(int j=0;j<i;++j)
				s[j]=0;
			s.dft(i*2);
			for(int j=0;j<i*2;++j)
				s[j]=(u64)s[j]*(p-t[j])%p;
			s.idft();
			for(int j=i;j<i*2;++j)
				b[j]=s[j];
		}
		mod(n);
		if(r){
			b.mod(n);
			b.swap(*r);
		}
	}
	void pow(u32 k,int n){
		mod(n);
		if(k==0){
			zero(n);
			a[0]=1;
		}else if(a[0]==1){
			log(n);
			imul(k);
			exp(n);
		}else if(a[0]){
			u32 w=num::inv(a[0]);
			u32 v=ipow(a[0],k);
			imul(w);
			pow(k,n);
			imul(v);
		}else{
			for(int i=1;i<n;++i)
				if(a[i]){
					if(i>=n/k){
						zero(n);
					}else{
						shr(i);
						pow(k,n-i*k);
						shl(i*k);
					}
					break;
				}
		}
	}
	poly mod_bf(poly b){
		fix();
		b.fix();
		int n=size();
		int m=b.size();
		if(n<m)
			return poly();
		poly q(n-m+1);
		u64 w=num::inv(b[m-1]);
		for(;n>=m;--n)
			if(u64 s=a[n-1]*w%p){
				q[n-m]=s;
				for(int i=n-1;i>=n-m;--i)
					a[i]=(a[i]+(p-s)*b[i-n+m])%p;
			}
		fix();
		return q;
	}
	void gcd(poly b){
		fix();
		b.fix();
		while(b.size()){
			mod_bf(b);
			swap(b);
		}
	}
	void div(poly b){
		fix();
		b.fix();
		int n=size()-b.size()+1;
		if(n<=0)
			return zero(0);
		reverse(RAN(a));
		mod(n);
		reverse(RAN(b.a));
		swap(b);
		inv_mul(b,n);
		reverse(RAN(a));
		fix();
	}
	void mod(poly b){
		fix();
		b.fix();
		int n=b.size();
		if(size()>=n){
			poly c=*this;
			div(b);
			mul(b,n-1);
			for(int i=0;i<n-1;++i)
				a[i]=imod(c[i]+p-a[i]);
			fix();
		}
	}
	u32 val(u32 x)const{
		u32 s=0;
		for(int i=size()-1;~i;--i)
			s=((u64)s*x+a[i])%p;
		return s;
	}
	u32 operator()(u32 x)const{
		return val(x);
	}
	template<class ite1,class ite2>
	void val(int n,ite1 x,ite2 y)const;
	template<class ite>
	poly(int n,ite x);
	template<class ite1,class ite2>
	poly(int n,ite1 x,ite2 y);
	class seg;
private:
	void mod(poly b,const poly&f);
	template<class ite>
	static poly gen(int n,ite a);
	template<class ite>
	static poly gen(int x,int y,ite&a);
	template<class ite>
	static poly gen_bf(int n,ite&a);
};
void poly::mod(poly b,const poly&f){
	fix();
	b.fix();
	int m=b.size();
	if(size()>=m){
		poly c=*this;
		div(b);
		dft(f.size());
		for(int i=0;i<f.size();++i)
			a[i]=(u64)a[i]*f[i]%p;
		idft();
		for(int i=0;i<m-1;++i)
			a[i]=imod(c[i]+p-a[i]);
		mod(m-1);
		fix();
	}
}
template<class ite>
poly::poly(int n,ite x){
	*this=gen(n,x);
}
template<class ite>
poly poly::gen(int n,ite a){
	return gen(0,n,a);
}
template<class ite>
poly poly::gen(int x,int y,ite&a){
	if(y-x<=200){
		return gen_bf(y-x,a);
	}else{
		int m=x+y>>1;
		poly b=gen(x,m,a);
		b.mul(gen(m,y,a));
		return b;
	}
}
template<class ite>
poly poly::gen_bf(int n,ite&a){
	poly f(1);
	f[0]=1;
	for(int i=0;i<n;++i){
		f.shl();
		u64 s=p-*a++;
		for(int j=0;j<=i;++j)
			f[j]=(f[j]+s*f[j+1])%p;
	}
	return f;
}
class poly::seg{
public:
	template<class ite>
	seg(int n,const ite&x):a(n){
		ite t=x;
		for(int i=0;i<n;++i)
			a[i]=*t++;
		dfs(0,n,1);
	}
	poly gen()const{
		return t[1].a;
	}
	template<class ite>
	poly gen(const ite&b)const{
		ite t=b;
		return gen(t);
	}
	template<class ite>
	void val(const ite&b,const poly&f)const{
		ite t=b;
		val(t,f);
	}
	template<class ite>
	seg(int n,ite&x):a(n){
		for(int i=0;i<n;++i)
			a[i]=*x++;
		dfs(0,n,1);
	}
	template<class ite>
	poly gen(ite&b)const{
		poly f=t[1].a;
		f.der();
		return gen(0,a.size(),1,b,f);
	}
	template<class ite>
	void val(ite&b,const poly&f)const{
		val(0,a.size(),1,b,f);
	}
private:
	struct pair{
		poly a,b;
	};
	vector<pair>t;
	vector<u32>a;
	void dfs(int x,int y,int k){
		if(t.size()<=k)
			t.resize(k+1);
		if(y-x<=200){
			t[k].a=poly::gen(y-x,a.begin()+x);
		}else{
			int m=x+y>>1,i=k<<1,j=i^1;
			dfs(x,m,i);
			dfs(m,y,j);
			int n=len(y-x+1);
			t[i].b=t[i].a;
			t[i].b.dft(n);
			t[j].b=t[j].a;
			t[j].b.dft(n);
			t[k].a=t[i].b;
			for(int l=0;l<n;++l)
				t[k].a[l]=(u64)t[k].a[l]*t[j].b[l]%p;
			t[k].a.idft();
		}
	}
	template<class ite>
	void val(int x,int y,int k,ite&b,poly f)const{
		if(k!=1)
			f.mod(t[k].a,t[k].b);
		else
			f.mod(t[k].a);
		if(y-x<=200){
			for(int i=0;i<y-x;++i)
				*b++=f(a[x+i]);
		}else{
			int m=x+y>>1,i=k<<1,j=i^1;
			val(x,m,i,b,f);
			val(m,y,j,b,f);
		}
	}
	template<class ite>
	poly gen(int x,int y,int k,ite&b,poly f)const{
		if(k!=1)
			f.mod(t[k].a,t[k].b);
		else
			f.mod(t[k].a);
		if(y-x<=200){
			vector<u64>c(y-x);
			for(int i=0;i<y-x;++i){
				u64 n=a[x+i];
				u64 m=*b++;
				m=m*num::inv(f(n))%p;
				u64 s=0;
				for(int j=y-x;j>0;--j){
					s=(t[k].a[j]+s*n)%p;
					c[j-1]+=s*m%p;
				}
			}
			poly d(y-x);
			for(int i=0;i<y-x;++i)
				d[i]=c[i]%p;
			return d;
		}else{
			int m=x+y>>1;
			int i=k<<1;
			int j=i^1;
			poly c=gen(x,m,i,b,f);
			poly d=gen(m,y,j,b,f);
			int n=len(y-x+1);
			c.dft(n);
			d.dft(n);
			for(int l=0;l<n;++l)
				c[l]=((u64)c[l]*t[j].b[l]+(u64)d[l]*t[i].b[l])%p;
			c.idft();
			return c;
		}
	}
};
template<class ite1,class ite2>
poly::poly(int n,ite1 x,ite2 y){
	*this=seg(n,x).gen(y);
}
template<class ite1,class ite2>
void poly::val(int n,ite1 x,ite2 y)const{
	int m=size();
	if(min(m,n)<=100)
		for(int i=0;i<n;++i)
			*y++=val(*x++);
	else
		for(int l=n;;l/=2)
			if(l*2<=m){
				for(int i=0;i<n;i+=l)
					seg(min(l,n-i),x).val(y,*this);
				break;
			}
}

2018-08-14

  • 计划今后增加$O(n\log^2n)$求多项式gcd的算法。

2018-08-18

  • 增加了$O(n^2)$的gcd,未保证结果为首一多项式。
  • 分离了基础部分和点值有关部分。

2018-08-19

  • 为避免影响编译器优化,删去了在运行时确定模数的功能。

2019-03-17

  • 增加了sin/cos/tan。

2019-06-28

  • 修复了除法中非预期地降低效率的问题。

2021-08-07

由于有道题被卡常数了,应用这篇博客中的部分技巧,大幅优化了主要的形式幂级数方法的常数。

  • 增加了inv_mul方法,比先inv再mul更快。
  • exp方法新增了可选的指针参数,若这个参数不为空指针,会将exp的逆保存到其中,比单独求逆更快。
  • 增加了用log和exp实现的pow方法。

以长度为n的DFT次数为单位,乘法的常数为6,下表说明了各方法的优化效果。

  • inv:12→10
  • inv_mul:18→13
  • log:18→13
  • exp:48→18
  • exp+inv_exp:60→22
  • pow:66→31

删去了没什么意义的三角函数。删去了未经优化的sqrt方法,现在通过pow来计算sqrt和原来的sqrt效率差不多。原来的模板可以在这里找到。

posted @ 2018-07-25 00:03  f321dd  阅读(737)  评论(0编辑  收藏  举报