多项式各种板子的快速实现
- 使用
memset
和memcpy
,对于long long
类型的数组,memset(a,0,n<<3)
可以清空前 \(n\) 项,memcpy(A,a,n<<3)
可以 复制前 \(n\) 项,比用for
更快更简洁。
多项式乘法 | 快速傅里叶变换 fft
blogs
update
- 使用了 进位器 代替
rev[]
,均摊线性,好处是不用每次预处理rev[]
。 - IDFT /n 的操作改为在内部完成。
typedef complex<double> cp;
const int N=4e5+5; // 开4倍
const double PI=acos(-1);
void fft(cp a[],int n,int inv)
{
for(int i=0,j=0;i<n;i++) {
if(i<j) swap(a[i],a[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
for(int k=2,m=1;k<=n;m=k,k<<=1) {
cp wn(cos(2*PI/k),inv*sin(2*PI/k));
for(int i=0;i<n;i+=k) {
cp w(1,0);
for(int j=0;j<m;j++,w*=wn) {
cp u=a[i+j],v=a[i+j+m];
a[i+j]=u+w*v,a[i+j+m]=u-w*v;
}
}
}
if(inv==1) return;
for(int i=0;i<n;i++) a[i].real(a[i].real()/n);
}
多项式乘法 | 快速数论变换 ntt
const LL P=998244353,G=3;
LL power(LL x,LL k,LL MOD)
{
LL res=1; x%=MOD;
while(k) {
if(k&1) res=res*x%MOD;
x=x*x%MOD; k>>=1;
}
return res%MOD;
}
void ntt(LL a[],int n,int inv)
{
for(int i=0,j=0;i<n;i++) {
if(i<j) swap(a[i],a[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
for(int m=1,k=2;k<=n;m=k,k<<=1) {
LL gn=power(G,(P-1)/k,P);
if(inv==-1) gn=power(gn,P-2,P);
for(int i=0;i<n;i+=k) {
LL g=1;
for(int j=0;j<m;j++,g=g*gn%P) {
LL u=a[i+j],v=a[i+j+m];
a[i+j]=u+g*v; a[i+j+m]=u-g*v;
}
}
for(int i=0;i<n;i++) a[i]=(a[i]%P+P)%P;
}
if(inv==1) return;
LL invn=power(n,P-2,P);
for(int i=0;i<n;i++) a[i]=a[i]*invn%P;
}
在乘数较大的时候不能用 fft,在模数 \(P\) 没有原根或者 \(P-1\) 不是 \(2\) 的高次幂的倍数的时候不能用 ntt.
任意模数多项式乘法 | mtt
- 注意使用
long double
,避免被卡精度。 - 项数为 \(n\) 和项数为 \(m\) 的多项式相乘,最高次幂为 \(x^{n+m-2}\) ,共有 \(n+m-1\) 项,故
while(lim<n+n-1) lim<<=1;
。
void mtt(LL a[],LL b[],LL c[],int n)
{
static cp A[N],B[N],C[N],D[N],E[N],F[N];
static const LL M=(1ll<<15);
int lim=1;
while(lim<n+n-1) lim<<=1;
for(int i=0;i<lim;i++) A[i]=B[i]=C[i]=D[i]=cp(0,0);
for(int i=0;i<lim;i++) A[i].real(a[i]/M),B[i].real(a[i]%M);
for(int i=0;i<lim;i++) C[i].real(b[i]/M),D[i].real(b[i]%M);
fft(A,lim,1); fft(B,lim,1); fft(C,lim,1); fft(D,lim,1);
for(int i=0;i<lim;i++) E[i]=A[i]*C[i],F[i]=B[i]*D[i];
fft(E,lim,-1); fft(F,lim,-1);
for(int i=0;i<lim;i++) c[i]=((LL)round(E[i].real()))%MOD*M%MOD*M%MOD+(LL)round(F[i].real());
for(int i=0;i<lim;i++) c[i]%=MOD;
for(int i=0;i<lim;i++) E[i]=A[i]*D[i]+B[i]*C[i];
fft(E,lim,-1);
for(int i=0;i<lim;i++) c[i]=(c[i]+((LL)round(E[i].real()))%MOD*M%MOD)%MOD;
}
任意模数多项式乘法 | 三模 ntt
- 三个模数 \(998244353,469762049,1004535809\)。
typedef long long LL;
const int N=4e5+5;
const LL Ps[]={998244353,469762049,1004535809},G=3;
LL muler(LL x,LL y,LL MOD)
{
x=(x%MOD+MOD)%MOD; y=(y%MOD+MOD)%MOD;
LL high=(long double)x/MOD*y;
LL low=x*y-high*MOD;
return (low%MOD+MOD)%MOD;
}
LL power(LL x,LL k,LL MOD)
{
LL res=1; x%=MOD;
while(k) {
if(k&1) res=res*x%MOD;
x=x*x%MOD; k>>=1;
}
return res%MOD;
}
void ntt(LL a[],int n,int inv,LL P)
{
for(int i=0,j=0;i<n;i++) {
if(i<j) swap(a[i],a[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
for(int m=1,k=2;k<=n;m=k,k<<=1) {
LL gn=power(G,(P-1)/k,P);
if(inv==-1) gn=power(gn,P-2,P);
for(int i=0;i<n;i+=k) {
LL g=1;
for(int j=0;j<m;j++,g=g*gn%P) {
LL u=a[i+j],v=a[i+j+m];
a[i+j]=u+g*v; a[i+j+m]=u-g*v;
}
}
for(int i=0;i<n;i++) a[i]=(a[i]%P+P)%P;
}
if(inv==1) return;
LL invn=power(n,P-2,P);
for(int i=0;i<n;i++) a[i]=a[i]*invn%P;
}
void TMntt(LL a[],LL b[],LL c[],int n,LL MOD)
{
static LL A[N],B[N],E[N],F[N],Ans[3][N];
static const LL Inv0=power(Ps[1],Ps[0]-2,Ps[0]),Inv1=power(Ps[0],Ps[1]-2,Ps[1]);
static const LL M=Ps[0]*Ps[1],InvM=power(M,Ps[2]-2,Ps[2]);
int lim=1;
while(lim<n+n-1) lim<<=1;
memset(A,0,lim<<3); memset(B,0,lim<<3);
memcpy(A,a,n<<3); memcpy(B,b,n<<3);
for(int p=0;p<3;p++) {
for(int i=0;i<lim;i++) E[i]=A[i]%Ps[p],F[i]=B[i]%Ps[p];
ntt(E,lim,1,Ps[p]); ntt(F,lim,1,Ps[p]);
for(int i=0;i<lim;i++) Ans[p][i]=(E[i]*F[i])%Ps[p];
ntt(Ans[p],lim,-1,Ps[p]);
}
for(int i=0;i<lim;i++) {
LL x=(muler(Ans[0][i]*Ps[1],Inv0,M)+muler(Ans[1][i]*Ps[0],Inv1,M))%M;
LL y=muler(Ans[2][i]-x,InvM,Ps[2]);
c[i]=(muler(y,M,MOD)+x)%MOD;
}
}
多项式牛顿迭代 Newton’s Method
牛顿迭代公式
\[F(x)=F_0(x)-{G(F_0(x))\over G'(F_0(x))} \pmod {x ^n}
\]
一般过程
已知 \(G(x) \bmod x^n\) ,求 \(F(x)\) 使得 \(G(F(x))=0 \bmod x^n\)
- 当 \(n=1\) 时,答案可以直接得出。
- 递归求解 $F_0(x) \bmod x ^{\lceil {n \over 2} \rceil} $ ,使得 \(G(F_0(x))=0 \pmod {x ^{\lceil {n \over 2} \rceil}}\)。
- 已经求出了 \(F_0(x)\),\(F(x)=F_0(x)-{G(F_0(x))\over G'(F_0(x))} \pmod {x ^n}\)
多项式求逆 inv
void inv(LL a[],LL b[],rint n)
{
static LL A[N],B[N];
if(n==1)
return void(b[0]=power(a[0],P-2,P));
int len=(n+1)>>1;
inv(a,b,len);
int lim=1;
while(lim<n+n-1) lim<<=1;
memset(A,0,lim<<3); // 其他项要清零。
memset(B,0,lim<<3);
memcpy(A,a,n<<3); // 注意是在 mod x^n 前提下的,只要 copy 前 n 项。
memcpy(B,b,len<<3);
ntt(B,lim,1); ntt(A,lim,1);
for(int i=0;i<lim;i++) B[i]=(2+P-B[i]*A[i]%P)*B[i]%P;
ntt(B,lim,-1);
memcpy(b,B,n<<3);
}
多项式开根 sqrt
void sqrt(LL a[],LL b[],int n)
{
static LL A[N],B[N],C[N],D[N];
if(n==1)
return void(b[0]=power(G,bsgs(G,a[0],P)>>1,P));
int len=(n+1)>>1;
sqrt(a,b,len);
int lim=1;
while(lim<n+n-1) lim<<=1;
memset(A,0,lim<<3);
memset(B,0,lim<<3);
memset(C,0,lim<<3);
memcpy(B,b,len<<3);
memcpy(A,a,n<<3);
for(int i=0;i<len;i++) C[i]=b[i]+b[i];
inv(C,D,n);
ntt(B,lim,1); ntt(A,lim,1); ntt(D,lim,1);
for(int i=0;i<lim;i++) B[i]=(B[i]*B[i]%P+A[i])%P*D[i]%P;
ntt(B,lim,-1);
memcpy(b,B,n<<3);
}
多项式求导 derivative
void derivative(LL a[],int n)
{
for(int i=0;i<n-1;i++) a[i]=a[i+1]*(i+1)%P;
a[n-1]=0;
}
多项式不定积分 inter
void inter(LL a[],int n)
{
for(int i=n-1;i>=1;i--) a[i]=a[i-1]*power(i,P-2,P)%P;
a[0]=0;
}
多项式对数 ln
void ln(LL a[],int n)
{
static LL A[N],B[N];
int lim=1;
while(lim<n+n-1) lim<<=1;
memset(A,0,lim<<3); memset(B,0,lim<<3);
memcpy(B,a,n<<3); inv(B,A,n);
memset(B,0,lim<<3); memcpy(B,a,n<<3); derivative(B,n);
ntt(A,lim,1); ntt(B,lim,1);
for(int i=0;i<lim;i++) A[i]=A[i]*B[i]%P;
ntt(A,lim,-1);
inter(A,n);
memcpy(a,A,n<<3);
}
多项式指数 exp
本质是 牛顿迭代。
void exp(LL a[],LL b[],int n)
{
static LL A[N],B[N];
if(n==1) return void(b[0]=1);
int len=(n+1)>>1;
exp(a,b,len);
int lim=1;
while(lim<n+n-1) lim<<=1;
memset(A,0,lim<<3); memcpy(A,b,len<<3);
memset(B,0,lim<<3); memcpy(B,b,len<<3);
ln(B,n);
for(int i=0;i<lim;i++) B[i]=(i==0)-B[i]+a[i]*(i<n);
ntt(A,lim,1); ntt(B,lim,1);
for(int i=0;i<lim;i++) B[i]=A[i]*B[i]%P;
ntt(B,lim,-1);
memcpy(b,B,n<<3);
}
多项式快速幂 pow
\(B(x)=\exp (k \ln A(x))\)
关于 \(k\) 为什么可以取模,因为
多项式 vector 封装版
- 封装好的 多项式板子,可能比较慢。
#define SZ(v) ((int)(v).size())
typedef long long LL;
typedef vector<LL> Poly;
const int N=8e5+5;
const LL P=998244353,G=3;
LL bsgs(LL a,LL b,LL p) // a^x = b (mod p)
{
if(b%p==1%p) return 0;
map<LL,LL> mp;
LL k=sqrt(p)+1,ak=1;
for(LL i=0,j=b;i<k;i++,j=j*a%p) mp[j]=i;
for(LL i=1;i<=k;i++) ak=ak*a%p;
for(LL i=1,j=ak;i<=k;i++,j=j*ak%p)
if(mp.count(j)) return i*k-mp[j];
return -1;
}
LL power(LL x,LL k,LL MOD)
{
LL res=1; x%=MOD;
while(k) {
if(k&1) res=res*x%MOD;
x=x*x%MOD; k>>=1;
}
return res%MOD;
}
void ntt(LL a[],int n,int inv)
{
for(int i=0,j=0;i<n;i++) {
if(i<j) swap(a[i],a[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
for(int m=1,k=2;k<=n;m=k,k<<=1) {
LL gn=power(G,(P-1)/k,P);
if(inv==-1) gn=power(gn,P-2,P);
for(int i=0;i<n;i+=k) {
LL g=1;
for(int j=0;j<m;j++,g=g*gn%P) {
LL u=a[i+j],v=a[i+j+m];
a[i+j]=u+g*v; a[i+j+m]=u-g*v;
}
}
for(int i=0;i<n;i++) a[i]=(a[i]%P+P)%P;
}
if(inv==1) return;
LL invn=power(n,P-2,P);
for(int i=0;i<n;i++) a[i]=a[i]*invn%P;
}
Poly operator*(Poly a,Poly b)
{
static LL A[N],B[N],C[N];
int lim=1;
while(lim<SZ(a)+SZ(b)-1) lim<<=1;
for(int i=0;i<SZ(a);i++) A[i]=a[i];
for(int i=SZ(a);i<lim;i++) A[i]=0;
for(int i=0;i<SZ(b);i++) B[i]=b[i];
for(int i=SZ(b);i<lim;i++) B[i]=0;
ntt(A,lim,1); ntt(B,lim,1);
for(int i=0;i<lim;i++) C[i]=A[i]*B[i]%P;
ntt(C,lim,-1);
Poly c;
for(int i=0;i<SZ(a)+SZ(b)-1;i++)
c.push_back(C[i]);
return c;
}
Poly operator+(Poly a,Poly b)
{
while(SZ(a)<SZ(b)) a.push_back(0);
while(SZ(b)<SZ(a)) b.push_back(0);
for(int i=0;i<SZ(a);i++) a[i]=(a[i]%P+b[i]%P+P+P)%P;
return a;
}
Poly operator-(Poly a,Poly b)
{
while(SZ(a)<SZ(b)) a.push_back(0);
while(SZ(b)<SZ(a)) b.push_back(0);
for(int i=0;i<SZ(a);i++) a[i]=(a[i]%P-b[i]%P+P+P)%P;
return a;
}
Poly operator*(Poly a,LL b)
{
b=(b%P+P)%P;
for(int i=0;i<SZ(a);i++) a[i]=a[i]*b%P;
return a;
}
Poly operator*(LL b,Poly a) { return a*b; }
Poly inv(Poly a)
{
if(SZ(a)==1)
return Poly(1,power(a[0],P-2,P));
int len=(SZ(a)+1)>>1;
Poly b=a; b.resize(len);
b=inv(b);
b=b+b-b*b*a; b.resize(SZ(a));
return b;
}
Poly sqrt(Poly a)
{
if(SZ(a)==1)
return Poly(1,power(G,bsgs(G,a[0],P)>>1,P));
int len=(SZ(a)+1)>>1;
Poly b=a; b.resize(len);
b=sqrt(b); b.resize(SZ(a),0); // 一定注意这里要 resize 一下,保证求的 inv 是在 mod x^n 意义下的。
b=(a+b*b)*inv(b+b); b.resize(SZ(a));
return b;
}
Poly derivative(Poly a)
{
for(int i=0;i<SZ(a)-1;i++) a[i]=a[i+1]*(i+1)%P;
a[SZ(a)-1]=0;
return a;
}
Poly inter(Poly a)
{
for(int i=SZ(a)-1;i>=1;i--) a[i]=a[i-1]*power(i,P-2,P)%P;
a[0]=0;
return a;
}
Poly ln(Poly a)
{
Poly res=inter(derivative(a)*inv(a));
res.resize(SZ(a));
return res;
}
Poly exp(Poly a)
{
if(SZ(a)==1) return Poly(1,1);
int len=(SZ(a)+1)>>1;
Poly b=a; b.resize(len);
b=exp(b); b.resize(SZ(a),0);
b=b*(Poly(1,1)-ln(b)+a); b.resize(SZ(a));
return b;
}
Poly pow(Poly a,LL k) { return exp(k*ln(a)); }
ostream &operator<<(ostream& cout,Poly v)
{
for(int i=0;i<SZ(v);i++) printf("%lld ",v[i]);
return cout;
}
istream &operator>>(istream& cin,Poly& v)
{
int n; LL x;
v.clear();
scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%lld",&x),v.push_back(x);
return cin;
}
废弃区:(一堆伪证)
证明当 \(n\) 项多项式\(A(x)\) 的 \(0\) 次项为 \(1\) 时,各项系数在 \(\bmod \: p\) 意义
下有 \(A(x)^p=1 \pmod {x^{n}}\)。
数学归纳法:
- 当 \(n=1\) 时成立。
- 假设当 \(n=k\) 时成立,那么当 \(n=k+1\) 时,不妨设 \(A(x)=A_0(x)+bx^k\).
\[A(x)^p=(A_0(x)+bx^k)^p=\sum_{i=0}^p {p \choose i}A_0(x)^{p-i}(bx^{k})^i=\sum_{i=0}^p {p \choose i}A_0(x)^{p-i}b^i x^{ik} \pmod {x^{k+1}}
\]
当 \(i \ge 1\) 时 \(ik \ge k+2\),被舍去。
有
\[A(x)^p=\sum_{i=0}^1{p \choose i}A_0(x)^{p-i}b^i x^{ik} =A_0(x)^p+p \cdot A_0(x)^{p-1}bx^k =A_0(x)^p=1 \pmod {x^{k+1}}
\]
\(A(x)^k=\exp(k\ln {A(x)})\)
\(\exp(kx)=\sum_{i=0}^{+\infty} {k^ix^i \over i!}\)
因此可以对 \(k\) 取模。