多项式板子

FFT

\[w_n=\cos\frac{2\pi}{n}+\sin\frac{2\pi}{n} i \]

\[F(\omega_n^k)=A(\omega_{n/2}^k)+\omega_{n}^k\times B(\omega_{n/2}^k) \]

\[F(\omega_n^{k+n/2})=A(\omega_{n/2}^k)-\omega_{n}^k\times B(\omega_{n/2}^k) \]

struct Complex{
    db x,y;
    Complex(db x_=0,db y_=0):x(x_),y(y_){}
    Complex operator +(Complex a){return Complex(x+a.x,y+a.y);}
    Complex operator -(Complex a){return Complex(x-a.x,y-a.y);}
    Complex operator *(Complex a){return Complex(x*a.x-y*a.y,y*a.x+x*a.y);}
};

inline void FFT(){
    for(rint i=0;i<n;i++)
        if(i<rk[i]) swap(a[i],a[rk[i]]);
    for(rint p=2;p<=n;p<<=1){
        int len=p>>1;
        Complex ret(cos(2.0*PI/p),sin(2.0*PI/p)*opt);
        for(rint k=0;k<n;k+=p){
            Complex now=Complex(1.0,0.0);
            for(rint l=k;l<k+len;l++){
                Complex t=now*a[l+len];
                a[l+len]=a[l]-t;
                a[l]=a[l]+t;
                now=now*ret;
            }
        }
    }
}

int main(){
    for(int i=0;i<n;i++)
        rk[i]=(rk[i>>1]>>1)|((i&1)? n>>1:0);
    opt=1;FFT();
    for(rint i=0;i<n;++i) a[i]=a[i]*a[i];
    opt=-1;FFT();
    for(rint i=0;i<=m;++i) printf("%.0lf ",fabs(a[i].y)/n/2.0);
}

NTT

\(p\) 的原根 \(g\) 替换 \(\omega\),因为原根有类似的性质。

\[\varphi(p)=2^p\times r \]

其中 \(2^p\) 决定了最大长度。

void NTT(bool op,int n,ll *F){
    for(int i=0;i<n;i++)
        if(i<rk[i]) swap(F[i],F[rk[i]]);
    for(int p=2;p<=n;p<<=1){
        int len=p>>1;
        ll w=qpow(op? G:Gi,(Mod-1)/p);
        for(int k=0;k<n;k+=p){
            ll now=1;
            for(int l=k;l<k+len;l++){
                ll t=F[l+len]*now%Mod;
                F[l+len]=(F[l]-t+Mod)%Mod;
                F[l]=(F[l]+t)%Mod;
                now=now*w%Mod;
            }
        }
    }
}

Mul

分别 NTT 后,将点值乘起来,再 NTT 回去。注意最后要除 \(n\)

\(n\)\(m\) 都是最高次幂大小。注意清空 \(x\)\(y\)

inline void Cop(int n,ll *a,ll *b){for(int i=0;i<n;i++)a[i]=b[i];}
inline void Clear(int n,ll *F){for(int i=0;i<n;i++)F[i]=0;}
inline void Rk(int n){for(int i=0;i<n;i++)rk[i]=(rk[i>>1]>>1)|(i&1? n>>1:0);}

void Mul(ll *X,int n,int m,ll *a,ll *b){
    static ll x[N],y[N];
    Cop(n+1,x,a),Cop(m+1,y,b);
    for(m+=n,n=1;n<=m;n<<=1); Rk(n);
    NTT(1,n,x),NTT(1,n,y);
    for(int i=0;i<n;i++) x[i]=x[i]*y[i]%Mod;
    NTT(0,n,x); ll inv=qpow(n);
    for(int i=0;i<=m;i++) X[i]=x[i]*inv%Mod;
    Clear(n,x),Clear(n,y);
}

Inv

\(A(x)\) 的逆元 \(B(x)\)\(a_0\) 非零。

先求出 \(A(x)\) 的常数项的逆元,设为初始的 \(B(x)\)。现在已知

\[A(x) \equiv B(x) \pmod{x^n} \]

可以得到

\[A(x)B(x) \equiv 1 \pmod{x^n} \]

\[\big(A(x)B(x)-1\big)^2 \equiv 0 \pmod{x^{2n}} \]

\[A(x)\big(2B(x)-A(x)B(x)^2\big) \equiv 0 \pmod{x^{2n}} \]

新的 \(B(x)\) 就是 \(2B(x)-A(x)B(x)^2\) 。递归即可,复杂度 \(O(n \log n)\)

注意 \(n\) 是项数,也即多项式长度,而我们上述式子所倍增的是多项式最高次幂,也就是平方能得到的是最高次幂的倍增,不是长度的倍增。而且会发现最高次幂的倍增会慢于长度的倍增,这就会导致求出来的最后几项实际上是虚拟的。一个解决的办法是 \(n\) 取大于等于 \(2m\) 的值,这样虽然会慢,但求出来一定是对的。

void Inv(int n,ll *a,ll *b){
    static ll x[N];
    if(n==1){b[0]=qpow(a[0]);return;}
    Inv((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); Rk(n);
    Clear(n,x),Cop(m,x,a);
    NTT(1,n,x),NTT(1,n,b);
    for(rint i=0;i<n;i++)
        b[i]=b[i]*(2-x[i]*b[i]%Mod+Mod)%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<m;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

ln

给定 \(A(x)\),且 \(a_0=1\)。求 \(B(x)=\ln A(x)\)

求导,有

\[B'(x)=\frac{A'(x)}{A(x)} \]

求逆即可,得到 \(B'(x)\),再积分回去。

void ln(int n,ll *a,ll *b){
    static ll x[N];
    Clear(n,x); Inv(n,a,x);
    for(int i=0;i<n-1;i++)
        b[i]=a[i+1]*(i+1)%Mod; b[n-1]=0;
    Mul(x,n-1,n-1,b,x);
    for(int i=1;i<n;i++) b[i]=x[i-1]*qpow(i)%Mod; b[0]=0;
}

exp

\(B(x)=e^{A(x)}\)

\[g(B(x))=\ln B(x)-A(x)\equiv 0 \pmod {x^n} \]

也就是要求 \(g\) 的一个多项式根。假如现在已经知道了 \(B\) 的前 \(n\) 项,即

\[B(x)\equiv B_0(x) \pmod {x^n} \]

\(x=B_0(x)\) 处泰勒展开,有

\[\begin{align} 0 &=g(B_0(x)) \\ &=g(B_0(x))+g'(B_0(x))\big(B(x)-B_0(x)\big)+\frac{g''(B_0(x))}{2}\big(B(x)-B_0(x)\big)^2+\dots \\ &=g(B_0(x))+g'(B_0(x))\big(B(x)-B_0(x)\big) \pmod {x^{2n}} \end{align} \]

化简得

\[B(x)\equiv B_0(x)-\frac{g\big(B_0(x)\big)}{g'\big(B_0(x)\big)} \]

代入 \(g\)

\[B(x)\equiv B_0(x)\Big(1-\ln B_0(x)+A(x)\Big) \pmod {x^{2n}} \]

倍增的时候 \(m\) 也要翻倍,和 \(Inv\) 同理。

void exp(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    exp((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); 
    Clear(n,x),Clear(n,y),Cop(m,y,a),ln(m,b,x); // x=ln(b)
    NTT(1,n,x),NTT(1,n,y),NTT(1,n,b);
    for(rint i=0;i<n;i++) b[i]=(1-x[i]+y[i]+Mod)%Mod*b[i]%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<n;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

Sqrt

\(B(x)^2\equiv A(x)\),保证 \(a_0=1\).

\[\begin{align} B(x)& \equiv B_0(x) \pmod{x^n}\\ \big(B(x)- B_0(x) \big )^2&\equiv 0 \pmod{x^{2n}}\\ B^2(x)+B_0(x)^2&\equiv 2B(x)B_0(x) \pmod{x^{2n}}\\ A(x)+B_0(x)^2&\equiv 2B(x)B_0(x) \pmod{x^{2n}}\\ B(x)&\equiv \frac{A(x)+B_0(x)^2}{2B_0(x)} \pmod{x^{2n}} \end{align} \]

void Sqrt(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    Sqrt((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1);
    Clear(n,x),Clear(n,y),Inv(m,b,x);
    Cop(m,y,a),Mul(y,m-1,m-1,x,y);
    ll inv=qpow(2ll);
    for(rint i=0;i<m;i++) b[i]=(b[i]+y[i])%Mod*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

完整板子

#include<stdio.h>
#define rint register int

typedef long long ll;

inline int read(){
    int x=0,flag=1; char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')flag=0;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-48;c=getchar();}
    return flag? x:-x; 
}

const int N=(1<<22)+7;
const int Mod=998244353;
const int G=3;

ll qpow(ll x,int y=Mod-2){
    ll ret=1;
    while(y){
        if(y&1) ret=ret*x%Mod;
        x=x*x%Mod,y>>=1; 
    }
    return ret;
}

const int Gi=qpow(G);

int rk[N];
inline void swap(ll &x,ll &y){x^=y,y^=x,x^=y;}
void NTT(bool op,int n,ll *F){
    for(rint i=0;i<n;i++)
        if(i<rk[i]) swap(F[i],F[rk[i]]);
    for(rint p=2;p<=n;p<<=1){
        rint len=p>>1;
        ll w=qpow(op? G:Gi,(Mod-1)/p);
        for(rint k=0;k<n;k+=p){
            ll now=1;
            for(rint l=k;l<k+len;l++){
                ll t=F[l+len]*now%Mod;
                F[l+len]=(F[l]-t+Mod)%Mod;
                F[l]=(F[l]+t)%Mod;
                now=now*w%Mod;
            }
        }
    }
}

inline void Cop(int n,ll *a,ll *b){for(int i=0;i<n;i++)a[i]=b[i];}
inline void Clear(int n,ll *F){for(int i=0;i<n;i++)F[i]=0;}
inline void Rk(int n){for(int i=0;i<n;i++)rk[i]=(rk[i>>1]>>1)|(i&1? n>>1:0);}

void Mul(ll *X,int n,int m,ll *a,ll *b){
    static ll x[N],y[N];
    Cop(n+1,x,a),Cop(m+1,y,b);
    for(m+=n,n=1;n<=m;n<<=1); Rk(n);
    NTT(1,n,x),NTT(1,n,y);
    for(rint i=0;i<n;i++) x[i]=x[i]*y[i]%Mod;
    NTT(0,n,x); ll inv=qpow(n);
    for(rint i=0;i<=m;i++) X[i]=x[i]*inv%Mod;
    Clear(n,x),Clear(n,y);
}

void Inv(int n,ll *a,ll *b){
    static ll x[N];
    if(n==1){b[0]=qpow(a[0]);return;}
    Inv((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); Rk(n);
    Clear(n,x),Cop(m,x,a);
    NTT(1,n,x),NTT(1,n,b);
    for(rint i=0;i<n;i++)
        b[i]=b[i]*(2-x[i]*b[i]%Mod+Mod)%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<m;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

void ln(int n,ll *a,ll *b){
    static ll x[N];
    Clear(n,x); Inv(n,a,x);
    for(int i=0;i<n-1;i++)
        b[i]=a[i+1]*(i+1)%Mod; b[n-1]=0;
    Mul(x,n-1,n-1,b,x);
    for(int i=1;i<n;i++) b[i]=x[i-1]*qpow(i)%Mod; b[0]=0;
}

void exp(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    exp((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); 
    Clear(n,x),Clear(n,y),Cop(m,y,a),ln(m,b,x); // x=ln(b)
    NTT(1,n,x),NTT(1,n,y),NTT(1,n,b);
    for(rint i=0;i<n;i++) b[i]=(1-x[i]+y[i]+Mod)%Mod*b[i]%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<n;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

void Sqrt(int n,ll *a,ll *b){
    static ll x[N],y[N];
    if(n==1){b[0]=1;return;}
    Sqrt((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1);
    Clear(n,x),Clear(n,y),Inv(m,b,x);
    Cop(m,y,a),Mul(y,m-1,m-1,x,y);
    ll inv=qpow(2ll);
    for(rint i=0;i<m;i++) b[i]=(b[i]+y[i])%Mod*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

int n,m;
ll a[N],b[N];

int main(){
    n=read();
    for(rint i=0;i<n;i++) a[i]=read();
    Func();
    for(rint i=0;i<n;i++) printf("%lld ",b[i]);
}
posted @ 2021-06-25 14:41  Kreap  阅读(58)  评论(0编辑  收藏  举报