多项式模板
这里给出了一个多项式/形式幂级数的类的实现。由于仅考虑形式幂级数的前n项,即$\!\bmod{x^n}$下的等价类,故其形式与多项式相同,因而在类的实现上没有对两者作区分。没有返回值的方法会将结果用于更新自身。对于所有用于形式幂级数的方法,参数n代表$\!\bmod{x^n}$,且保证调用后前n项均可访问。下表说明了各方法的复杂度。
- der(求导)/pri(积分):$O(n)$
- mul/inv/log/exp:$O(n\log n)$
- div/mod:$O(n\log n)$
- 在$m$个点处求值:$O(n\log n+m\log^2\min(m,n))$
- 以根集构造:$O(n\log^2n)$
- 以点集构造(插值):$O(n\log^2n)$
注意,对于某些算式,相比直接调用以上方法进行计算,使用针对性的实现能使效率大幅提高。
#include<algorithm> #include<vector> #define RAN(a)a.begin(),a.end() using namespace std; typedef unsigned long long u64; typedef unsigned u32; namespace num{ const u32 p=998244353; const u32 g=3; inline u32 imod(u32 a){ return a<p?a:a-p; } inline u32 ipow(u32 a,u32 n){ u32 s=1; for(;n;n>>=1){ if(n&1) s=(u64)s*a%p; a=(u64)a*a%p; } return s; } class inv_t{ public: inv_t():f(1,1){} u32 operator[](int n){ int m=f.size(); if(m<n){ f.resize(n); for(int i=m+1;i<=n;++i) f[i-1]=(u64)(p-p/i)*f[p%i-1]%p; } return f[n-1]; } u32 operator()(u32 a)const{ return ipow(a,p-2); } private: vector<u32>f; }inv; } using namespace num; class poly{ public: vector<u32>a; poly(){} explicit poly(int n):a(n){} poly(const vector<u32>&b,int n): a(b.begin(),b.begin()+n) {} poly(const poly&b,int n): a(b.a.begin(),b.a.begin()+n) {} u32&operator[](int i){ return a[i]; } const u32&operator[](int i)const{ return a[i]; } int size()const{ return a.size(); } void swap(poly&b){ a.swap(b.a); } void shl(int k=1){ a.insert(a.begin(),k,0); } void shr(int k=1){ a.erase(a.begin(),a.begin()+k); } void der(){ fix(); if(size()){ for(int i=1;i<size();++i) a[i-1]=(u64)i*a[i]%p; a.pop_back(); } } void pri(){ fix(); shl(); for(int i=size()-1;i>0;--i) a[i]=(u64)a[i]*num::inv[i]%p; } static int len(int n){ while(n^n&-n) n+=n&-n; return n; } void fft(int n,bool f){ a.resize(n); if(n<=1) return; for(int i=0,j=0;i<n;++i){ if(i<j) std::swap(a[i],a[j]); int k=n>>1; while((j^=k)<k) k>>=1; } vector<u32>w(n/2); w[0]=1; for(int i=1;i<n;i<<=1){ for(int j=i/2-1;~j;--j) w[j<<1]=w[j]; int m=(p-1)/i/2; u64 s=ipow(g,f?p-1-m:m); for(int j=1;j<i;j+=2) w[j]=s*w[j-1]%p; for(int j=0;j<n;j+=i<<1){ u32*b=&a[0]+j,*c=b+i; for(int k=0;k<i;++k){ u32 v=(u64)w[k]*c[k]%p; c[k]=imod(b[k]+p-v); b[k]=imod(b[k]+v); } } } } void dft(int n){ fft(n,0); } void idft(){ int n=size(); fft(n,1); u64 f=num::inv(n); for(int i=0;i<n;++i) a[i]=f*a[i]%p; } void fix(){ while(size()&&!a.back()) a.pop_back(); } void mul_ref(poly&b){ fix(); if(&b!=this) b.fix(); int n=len(size()+b.size()-1); dft(n); if(&b!=this) b.dft(n); for(int i=0;i<n;++i) a[i]=(u64)a[i]*b[i]%p; idft(); fix(); } void mul(poly b){ mul_ref(b); } void sqr(){ mul_ref(*this); } void imul(u32 k){ for(int i=0;i<size();++i) a[i]=(u64)k*a[i]%p; } void mod(int n){ a.resize(n); } void zero(int n){ a.clear(); mod(n); } void mul(poly b,int n){ b.mod(n); mul_ref(b); mod(n); } void inv(int n){ int m=len(n); mod(m); poly f(m); swap(f); a[0]=num::inv(f[0]); for(int i=1;i<m;i<<=1){ poly s(f,i*2); poly t(a,i*2); s.dft(i*2); t.dft(i*2); for(int j=0;j<i*2;++j) s[j]=(u64)s[j]*t[j]%p; s.idft(); for(int j=0;j<i;++j) s[j]=0; s.dft(i*2); for(int j=0;j<i*2;++j) s[j]=(u64)s[j]*(p-t[j])%p; s.idft(); for(int j=i;j<i*2;++j) a[j]=s[j]; } mod(n); } void inv_mul(poly b,int n){ if(n==1){ inv(n); return mul(b,n); } int m=len(n)/2; mod(m*2); b.mod(m*2); poly c(a,m); c.inv(m); c.dft(m*2); poly d(b,m); d.dft(m*2); for(int j=0;j<m*2;++j) d[j]=(u64)d[j]*c[j]%p; d.idft(); poly t(d,m); t.dft(m*2); dft(m*2); for(int j=0;j<m*2;++j) a[j]=(u64)a[j]*t[j]%p; idft(); for(int j=0;j<m;++j) a[j]=0; for(int j=m;j<m*2;++j) a[j]=imod(a[j]+p-b[j]); dft(m*2); for(int j=0;j<m*2;++j) a[j]=(u64)a[j]*(p-c[j])%p; idft(); for(int j=0;j<m;++j) a[j]=d[j]; mod(n); } void log(int n){ if(n==1) return zero(n); mod(n); poly b=*this; b.der(); inv_mul(b,n-1); pri(); mod(n); } void exp(int n,poly*r=0){ int m=len(n); mod(m); poly f(m); swap(f); a[0]=1; poly b(m); b[0]=1; poly c(1); c[0]=1; for(int i=1;i<m;i<<=1){ poly s(f,i); s.der(); s.dft(i); for(int j=0;j<i;++j) s[j]=(u64)s[j]*c[j]%p; s.idft(); poly t(a,i); t.der(); t.mod(i-1); s.mod(i*2); for(int j=0;j<i-1;++j){ s[j+i]=imod(t[j]+p-s[j]); s[j]=0; } s[i-1]=imod(p-s[i-1]); s.dft(i*2); t=poly(b,i); t.dft(i*2); for(int j=0;j<i*2;++j) s[j]=(u64)s[j]*t[j]%p; s.idft(); for(int j=0;j<i-1;++j) s[j]=0; s.pri(); s.mod(i*2); for(int j=i;j<i*2;++j) s[j]=imod(s[j]+p-f[j]); s.dft(i*2); c=poly(a,i); c.dft(i*2); for(int j=0;j<i*2;++j) s[j]=(u64)s[j]*c[j]%p; s.idft(); for(int j=i;j<i*2;++j) a[j]=imod(p-s[j]); if(!r&&i==m/2) break; c=poly(a,i*2); c.dft(i*2); for(int j=0;j<i*2;++j) s[j]=(u64)c[j]*t[j]%p; s.idft(); for(int j=0;j<i;++j) s[j]=0; s.dft(i*2); for(int j=0;j<i*2;++j) s[j]=(u64)s[j]*(p-t[j])%p; s.idft(); for(int j=i;j<i*2;++j) b[j]=s[j]; } mod(n); if(r){ b.mod(n); b.swap(*r); } } void pow(u32 k,int n){ mod(n); if(k==0){ zero(n); a[0]=1; }else if(a[0]==1){ log(n); imul(k); exp(n); }else if(a[0]){ u32 w=num::inv(a[0]); u32 v=ipow(a[0],k); imul(w); pow(k,n); imul(v); }else{ for(int i=1;i<n;++i) if(a[i]){ if(i>=n/k){ zero(n); }else{ shr(i); pow(k,n-i*k); shl(i*k); } break; } } } poly mod_bf(poly b){ fix(); b.fix(); int n=size(); int m=b.size(); if(n<m) return poly(); poly q(n-m+1); u64 w=num::inv(b[m-1]); for(;n>=m;--n) if(u64 s=a[n-1]*w%p){ q[n-m]=s; for(int i=n-1;i>=n-m;--i) a[i]=(a[i]+(p-s)*b[i-n+m])%p; } fix(); return q; } void gcd(poly b){ fix(); b.fix(); while(b.size()){ mod_bf(b); swap(b); } } void div(poly b){ fix(); b.fix(); int n=size()-b.size()+1; if(n<=0) return zero(0); reverse(RAN(a)); mod(n); reverse(RAN(b.a)); swap(b); inv_mul(b,n); reverse(RAN(a)); fix(); } void mod(poly b){ fix(); b.fix(); int n=b.size(); if(size()>=n){ poly c=*this; div(b); mul(b,n-1); for(int i=0;i<n-1;++i) a[i]=imod(c[i]+p-a[i]); fix(); } } u32 val(u32 x)const{ u32 s=0; for(int i=size()-1;~i;--i) s=((u64)s*x+a[i])%p; return s; } u32 operator()(u32 x)const{ return val(x); } template<class ite1,class ite2> void val(int n,ite1 x,ite2 y)const; template<class ite> poly(int n,ite x); template<class ite1,class ite2> poly(int n,ite1 x,ite2 y); class seg; private: void mod(poly b,const poly&f); template<class ite> static poly gen(int n,ite a); template<class ite> static poly gen(int x,int y,ite&a); template<class ite> static poly gen_bf(int n,ite&a); }; void poly::mod(poly b,const poly&f){ fix(); b.fix(); int m=b.size(); if(size()>=m){ poly c=*this; div(b); dft(f.size()); for(int i=0;i<f.size();++i) a[i]=(u64)a[i]*f[i]%p; idft(); for(int i=0;i<m-1;++i) a[i]=imod(c[i]+p-a[i]); mod(m-1); fix(); } } template<class ite> poly::poly(int n,ite x){ *this=gen(n,x); } template<class ite> poly poly::gen(int n,ite a){ return gen(0,n,a); } template<class ite> poly poly::gen(int x,int y,ite&a){ if(y-x<=200){ return gen_bf(y-x,a); }else{ int m=x+y>>1; poly b=gen(x,m,a); b.mul(gen(m,y,a)); return b; } } template<class ite> poly poly::gen_bf(int n,ite&a){ poly f(1); f[0]=1; for(int i=0;i<n;++i){ f.shl(); u64 s=p-*a++; for(int j=0;j<=i;++j) f[j]=(f[j]+s*f[j+1])%p; } return f; } class poly::seg{ public: template<class ite> seg(int n,const ite&x):a(n){ ite t=x; for(int i=0;i<n;++i) a[i]=*t++; dfs(0,n,1); } poly gen()const{ return t[1].a; } template<class ite> poly gen(const ite&b)const{ ite t=b; return gen(t); } template<class ite> void val(const ite&b,const poly&f)const{ ite t=b; val(t,f); } template<class ite> seg(int n,ite&x):a(n){ for(int i=0;i<n;++i) a[i]=*x++; dfs(0,n,1); } template<class ite> poly gen(ite&b)const{ poly f=t[1].a; f.der(); return gen(0,a.size(),1,b,f); } template<class ite> void val(ite&b,const poly&f)const{ val(0,a.size(),1,b,f); } private: struct pair{ poly a,b; }; vector<pair>t; vector<u32>a; void dfs(int x,int y,int k){ if(t.size()<=k) t.resize(k+1); if(y-x<=200){ t[k].a=poly::gen(y-x,a.begin()+x); }else{ int m=x+y>>1,i=k<<1,j=i^1; dfs(x,m,i); dfs(m,y,j); int n=len(y-x+1); t[i].b=t[i].a; t[i].b.dft(n); t[j].b=t[j].a; t[j].b.dft(n); t[k].a=t[i].b; for(int l=0;l<n;++l) t[k].a[l]=(u64)t[k].a[l]*t[j].b[l]%p; t[k].a.idft(); } } template<class ite> void val(int x,int y,int k,ite&b,poly f)const{ if(k!=1) f.mod(t[k].a,t[k].b); else f.mod(t[k].a); if(y-x<=200){ for(int i=0;i<y-x;++i) *b++=f(a[x+i]); }else{ int m=x+y>>1,i=k<<1,j=i^1; val(x,m,i,b,f); val(m,y,j,b,f); } } template<class ite> poly gen(int x,int y,int k,ite&b,poly f)const{ if(k!=1) f.mod(t[k].a,t[k].b); else f.mod(t[k].a); if(y-x<=200){ vector<u64>c(y-x); for(int i=0;i<y-x;++i){ u64 n=a[x+i]; u64 m=*b++; m=m*num::inv(f(n))%p; u64 s=0; for(int j=y-x;j>0;--j){ s=(t[k].a[j]+s*n)%p; c[j-1]+=s*m%p; } } poly d(y-x); for(int i=0;i<y-x;++i) d[i]=c[i]%p; return d; }else{ int m=x+y>>1; int i=k<<1; int j=i^1; poly c=gen(x,m,i,b,f); poly d=gen(m,y,j,b,f); int n=len(y-x+1); c.dft(n); d.dft(n); for(int l=0;l<n;++l) c[l]=((u64)c[l]*t[j].b[l]+(u64)d[l]*t[i].b[l])%p; c.idft(); return c; } } }; template<class ite1,class ite2> poly::poly(int n,ite1 x,ite2 y){ *this=seg(n,x).gen(y); } template<class ite1,class ite2> void poly::val(int n,ite1 x,ite2 y)const{ int m=size(); if(min(m,n)<=100) for(int i=0;i<n;++i) *y++=val(*x++); else for(int l=n;;l/=2) if(l*2<=m){ for(int i=0;i<n;i+=l) seg(min(l,n-i),x).val(y,*this); break; } }
2018-08-14
- 计划今后增加$O(n\log^2n)$求多项式gcd的算法。
2018-08-18
- 增加了$O(n^2)$的gcd,未保证结果为首一多项式。
- 分离了基础部分和点值有关部分。
2018-08-19
- 为避免影响编译器优化,删去了在运行时确定模数的功能。
2019-03-17
- 增加了sin/cos/tan。
2019-06-28
- 修复了除法中非预期地降低效率的问题。
2021-08-07
由于有道题被卡常数了,应用这篇博客中的部分技巧,大幅优化了主要的形式幂级数方法的常数。
- 增加了inv_mul方法,比先inv再mul更快。
- exp方法新增了可选的指针参数,若这个参数不为空指针,会将exp的逆保存到其中,比单独求逆更快。
- 增加了用log和exp实现的pow方法。
以长度为n的DFT次数为单位,乘法的常数为6,下表说明了各方法的优化效果。
- inv:12→10
- inv_mul:18→13
- log:18→13
- exp:48→18
- exp+inv_exp:60→22
- pow:66→31
删去了没什么意义的三角函数。删去了未经优化的sqrt方法,现在通过pow来计算sqrt和原来的sqrt效率差不多。原来的模板可以在这里找到。