多项式

这是优美的多项式家族

快速傅里叶变换(FFT)

问题:多项式乘法

原理先不写了,思想就是把系数表达转化为点值表达,点值运算之后再变回系数表达,复杂度\(O(nlogn)\)

点值选取的是负数域中的n次单位根

有时间会补上这块内容的

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
const int N = 4e6;
const double Pi = acos(-1.0);
using namespace std;
struct node
{
    double x,y;
}a[N + 5],b[N + 5],w[N + 5];
int n,m,maxn,rev[N + 5],lg;
node operator +(node a,node b)
{
    return (node){a.x + b.x,a.y + b.y};
}
node operator -(node a,node b)
{
    return (node){a.x - b.x,a.y - b.y};
}
node operator *(node a,node b)
{
    return (node){a.x * b.x - a.y * b.y,a.x * b.y + a.y * b.x};
}
void fft(node *a,int typ)
{
    for (int i = 0;i < maxn;i++)
        if (i < rev[i])
            swap(a[i],a[rev[i]]);
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < maxn;j += i << 1)
            for (int k = 0;k < i;k++)
            {
                node x = a[k + j],t = (node){w[i + k].x,w[i + k].y * typ} * a[k + j + i];
                a[k + j] = x + t;
                a[k + j + i] = x - t;
            }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 0;i <= n;i++)
        scanf("%lf",&a[i].x);
    for (int i = 0;i <= m;i++)
        scanf("%lf",&b[i].x);
    maxn = 1;
    while (maxn <= m + n)
        maxn <<= 1,lg++;
    for (int i = 0;i <= maxn;i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < i;j++)
            w[i + j] = (node){cos(Pi * j / i),sin(Pi * j / i)}; 
    fft(a,1);
    fft(b,1);
    for (int i = 0;i < maxn;i++)
        a[i] = a[i] * b[i];
    fft(a,-1);
    for (int i = 0;i <= n + m;i++)
        printf("%d ",(int)(a[i].x / maxn + 0.1));
    return 0;
}

快速数论变换(NTT)

就是把问题转化为了在模意义下,于是我们可以选择和单位根有类似性质的原根,时间复杂度仍是\(O(nlogn)\)

#include <iostream>
#include <cstdio>
#include <algorithm>
const int N = 5e6;
const int P = 998244353;
using namespace std;
int n,m,rev[N + 5],maxn,lg,a[N + 5],b[N + 5],g[N + 5][3];
int mypow(int a,int x)
{
    int s = 1;
    while (x)
    {
        if (x & 1)
            s = 1ll * s * a % P;
        a = 1ll * a * a % P;
        x >>= 1;
    }
    return s;
}
void ntt(int *a,int typ)
{
    for (int i = 0;i < maxn;i++)
        if (i < rev[i])
            swap(a[i],a[rev[i]]);
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < maxn;j += i << 1)
            for (int k = 0;k < i;k++)
            {
                int x = a[k + j],t = 1ll * g[k + i][typ] * a[k + i + j] % P;
                a[k + j] = (x + t) % P;
                a[k + i + j] = ((x - t) % P + P) % P;
            }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 0;i <= n;i++)
        scanf("%d",&a[i]);
    for (int i = 0;i <= m;i++)
        scanf("%d",&b[i]);
    maxn = 1;
    while (maxn <= n + m)
        maxn <<= 1,lg++;
    for (int i = 0;i <= maxn;i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    for (int i = 1;i < maxn;i <<= 1)
    {
        int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
        g[i][1] = 1;
        g[i][0] = 1;
        for (int j = 1;j < i;j++)
            g[i + j][1] = 1ll * g[i + j - 1][1] * G1 % P,g[i + j][0] = 1ll * g[i + j - 1][0] * G2 % P;
    }
    ntt(a,1);
    ntt(b,1);
    for (int i = 0;i < maxn;i++)
        a[i] = 1ll * a[i] * b[i] % P;
    ntt(a,0);
    int inv = mypow(maxn,P - 2);
    for (int i = 0;i <= n + m;i++)
        printf("%d ",1ll * a[i] * inv % P);
    return 0;
}

多项式求逆

问题:给定一个多项式\(F(x)\),求一个多项式\(G(x)\),满足\(F(x)G(x)\equiv 1(mod\ x^n)\)

假设我们已经求出了一个\(F(x)\)\(mod\ x^n\)下的逆\(G'(x)\),我们要求在\(mod\ x^{2n}\)下的逆\(G(x)\)

那么考虑

\[\begin{aligned} F(x)G'(x)&\equiv 1(mod\ x^n) \\ F(x)G(x)&\equiv 1(mod\ x^n) \\ F(x)(G(x)-G'(x))&\equiv 0(mod\ x^n) \\ F^2(x)(G(x)-G'(x))^2&\equiv 0(mod\ x^{2n}) \\ F^2(x)G^2(x)-2F^2(x)G(x)G'(x)+F^2(x)G'^2(x)&\equiv 0(mod\ x^{2n}) \\ 1-2F(x)G'(x)+F^2(x)G'^2(x)&\equiv 0(mod\ x^{2n}) \\ G(x)-2G'(x)+F(x)G'^2(x)&\equiv 0(mod\ x^{2n}) \\ G(x)&\equiv 2G'(x)-F(x)G'^2(x)(mod\ x^{2n}) \end{aligned} \]

于是就可以愉快地递归求解了,时间复杂度\(T(n)=T(n/2)+O(nlogn)=O(nlogn)\)

Code

int INVa[N + 5];
void INV(int *a,int *ans,int n)
{
    if (n == 1)
    {
        ans[0] = mypow(a[0],p - 2);
        return;
    }
    INV(a,ans,n + 1 >> 1);
    pre(n * 2);
    for (int i = 0;i < n;i++)
        INVa[i] = a[i];
    clear(INVa,maxn,n);
    ntt(INVa,1);
    ntt(ans,1);
    for (int i = 0;i < maxn;i++)
        ans[i] = (2ll * ans[i] % p - 1ll * INVa[i] * ans[i] % p * ans[i] % p) % p;
    ntt(ans,0);
    clear(ans,maxn,n);
}

多项式对数函数(多项式 ln)

问题:给出 \(n-1\) 次多项式 \(A(x)\),求一个 \(\bmod{\:x^n}\) 下的多项式 \(B(x)\),满足 \(B(x) \equiv \ln A(x)\).

对两边同时求导\(B'(x)\equiv \frac{A'(x)}{A(x)}\)

积分回去\(B(x)\equiv \int \frac{A'(x)}{A(x)}dx\)

然后就是求导公式和积分公式

\[x^{a'}=ax^{a-1} \]

\[\int x^adx=\frac{1}{a+1}x^{a+1} \]

Code

int Lna[N + 5],Lnb[N + 5];
void DOV(int *a,int *f,int n)
{
    for (int i = 1;i < n;i++)
        f[i - 1] = 1ll * i * a[i] % p;
    f[n - 1] = 0;
}
void DOVINV(int *a,int *f,int n)
{
    f[0] = 0;
    for (int i = 1;i < n;i++)
        f[i] = 1ll * mypow(i,p - 2) * a[i - 1] % p;
}
void Ln(int *a,int *ans,int n)
{
    DOV(a,Lna,n);
    pre(n * 2);
    clear(Lnb,maxn);
    INV(a,Lnb,n);
    pre(n * 2);
    clear(Lna,maxn,n);
    ntt(Lna,1);
    ntt(Lnb,1);
    for (int i = 0;i < maxn;i++)
        Lna[i] = 1ll * Lna[i] * Lnb[i] % p;
    ntt(Lna,0);
    DOVINV(Lna,ans,n);
    clear(ans,maxn,n);
}

多项式指数函数(多项式 exp)

问题:给出 \(n-1\) 次多项式 \(A(x)\),保证\(A_0=0\),求一个 \(\bmod{\:x^n}\) 下的多项式 \(B(x)\),满足 \(B(x) \equiv \text e^{A(x)}\)

考虑用牛顿迭代解决这个问题

\[B(x)\equiv e^{A(x)} \]

\[lnB(x)-A(x)\equiv 0 \]

\(F(B(x))=lnB(x)-A(x)\)

\(A(x)\)看作常数项,所以\(F'(B(x))=\frac{1}{B(x)}\)

代入牛顿迭代的式子有

\[B(x)\equiv B_0(x)-\frac{F(B(x))}{F'(B(x))} \]

\[B(x)\equiv B_0(x)(1-lnB_0(x)+A(x)) \]

倍增求解即可

Code

int expa[N + 5],expb[N + 5];
void exp(int *a,int *ans,int n)
{
    if (n == 1)
    {
        ans[0] = 1;
        return;
    }
    exp(a,ans,n + 1 >> 1);
    Ln(ans,expa,n);
    pre(n * 2);
    for (int i = 0;i < n;i++)
        expb[i] = a[i];
    clear(expb,maxn,n);
    ntt(ans,1);
    ntt(expa,1);
    ntt(expb,1);
    for (int i = 0;i < maxn;i++)
        ans[i] = 1ll * ans[i] * ((1 - expa[i] + expb[i]) % p) % p;
    ntt(ans,0);
    clear(ans,maxn,n);
}

多项式快速幂

问题:给定一个 \(n-1\) 次多项式 \(A(x)\),求一个在 \(\bmod\ x^n\) 意义下的多项式 \(B(x)\),使得 \(B(x) \equiv A^k(x) \ (\bmod\ x^n)\)

我们对两边先ln再exp可以得到

\[B(x)\equiv exp(k\times ln(A(x))) \]

于是\(k\)也可以取模了

然后注意到数据不一定保证\(A_0=1\),那么我们可以找到第一个非\(0\)的项\(a\),把\(A(x)\)的每一项都除以\(a\),变成\(\frac{A(x)}{a}\),并将后面的移到前面,这样就可以保证\(A_0=1\),最后再乘\(a^k\)并且处理\(0\)即可

Code

int pa[N + 5];
void mypow(int *a,int *ans,int n,int k)
{
    Ln(a,pa,n);
    for (int i = 0;i < n;i++)
        pa[i] = 1ll * pa[i] * k % p;
    exp(pa,ans,n);
}

多项式开根

问题:给定一个\(n-1\)次多项式\(A(x)\),求一个在\(\bmod\ x^n\)意义下的多项式\(B(x)\),使得\(B^2(x) \equiv A(x) \ (\bmod\ x^n)\)。若有多解,请取零次项系数较小的作为答案。

\(H^2(x)\equiv F(x)(mod\ x^n)\)

那么考虑

\[\begin{aligned} G(x)&\equiv H(x)(mod\ x^n) \\ G(x)-H(x)&\equiv 0(mod\ x^n) \\ (G(x)-H(x))^2&\equiv 0(mod\ x^2n) \\ G^2(x)-2H(x)G(x)+H^2(x)&\equiv 0(mod\ x^{2n}) \\ G(x)&\equiv \frac{F(x)+H^2(x)}{2H(x)}(mod\ x^{2n}) \end{aligned} \]

倍增即可,只有一项的时候需要用二次剩余求根号

不过其实也可以先ln再exp回去

Code

int sqra[N + 5],sqrtmp[N + 5];
void sqr(int *a,int *ans,int n)
{
    if (n == 1)
    {
        ans[0] = sq;
        return;
    }
    sqr(a,ans,n + 1 >> 1);
    pre(n * 2);
    clear(sqra,maxn);
    clear(sqrtmp,maxn);
    INV(ans,sqra,n);
    pre(n * 2);
    for (int i = 0;i < n;i++)
        sqrtmp[i] = a[i];
    ntt(sqra,1);
    ntt(sqrtmp,1);
    ntt(ans,1);
    int t = mypow(2,p - 2);
    for (int i = 0;i < maxn;i++)
        ans[i] = 1ll * ((sqrtmp[i] + 1ll * ans[i] * ans[i] % p) % p) * t % p * sqra[i] % p;
    ntt(ans,0);
    int inv = mypow(maxn,p - 2);
    for (int i = 0;i < n;i++)
        ans[i] = 1ll * ans[i] * inv % p;
    clear(ans,maxn,n);
}

多项式除法

问题:给定一个\(n\)次多项式\(F(x)\)和一个\(m\)次多项式\(G(x)\),求出多项式\(Q(x),R(x)\)满足:

  • \(Q(x)\)次数为\(n-m\)\(R(x)\)次数小于\(m\)
  • \(F(x)=Q(x)G(x)+R(x)\)

首先设一个\(n\)项多项式\(A(x)\),假设一个\(r\)操作使得\(A_r(x)=x^nA(\frac{1}{x})\)

那么可以看出\(A_r[i]=A[n-i]\)

然后考虑下面的式子

\[\begin{aligned} F(x)&=Q(x)G(x)+R(x) \\ F(\frac{1}{x})&=Q(\frac{1}{x})G(\frac{1}{x})+R(\frac{1}{x}) \\ x^nF(\frac{1}{x})&=x^{n-m}Q(\frac{1}{x})x^mG(\frac{1}{x})+x^{n-m+1}\cdot x^{m-1}R(\frac{1}{x}) \\ F_r(x)&=Q_r(x)G_r(x)+x^{n-m+1}R_r(x) \\ F_r(x)&\equiv Q_r(x)G_r(x)+x^{n-m+1}R_r(x)(mod\ x^{n-m+1}) \\ F_r(x)&\equiv Q_r(x)G_r(x)(mod\ x^{n-m+1}) \\ Q_r(x)&\equiv F_r(x)G^{-1}_r(x)(mod\ x^{n-m+1}) \end{aligned} \]

于是我们对\(G_r(x)\)求逆,然后求得\(Q_r(x)\),再带回得到\(Q(x)\)

最后根据\(R(x)=F(x)-Q(x)G(x)\)求得\(R(x)\)

时间复杂度\(O(nlogn)\)

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int P = 998244353;
const int N = 1e6;
using namespace std;
int mypow(int a,int x)
{
    int s = 1;
    while (x)
    {
        if (x & 1)
            s = 1ll * s * a % P;
        a = 1ll * a * a % P;
        x >>= 1;
    }
    return s;
}
int n,m,F[N + 5],G[N + 5],Q[N + 5],GR[N + 5],w[N + 5][3],maxn,lg,rev[N + 5],Gi[N + 5],c[N + 5],FR[N + 5];
void R(int *a,int *b,int n)
{
    for (int i = 0;i <= n;i++)
        b[i] = a[n - i];
}
void ntt(int *a,int typ)
{
    for (int i = 0;i < maxn;i++)
        if (i < rev[i])
            swap(a[i],a[rev[i]]);
    for (int i = 1;i < maxn;i <<= 1)
        for (int j = 0;j < maxn;j += i << 1)
            for (int k = 0;k < i;k++)
            {
                int x = a[j + k],t = 1ll * w[i + k][typ] * a[i + j + k] % P;
                a[j + k] = (x + t) % P;
                a[j + k + i] = ((x - t) % P + P) % P;
            }
}
void ntt_pre(int n)
{
    maxn = 1;
    lg = 0;
    while (maxn <= n)
        maxn <<= 1,lg++;
    for (int i = 0;i < maxn;i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));    
}
void INV(int n,int *a,int *b)
{
    if (n == 1)
    {
        b[0] = mypow(a[0],P - 2);
        return;
    }
    INV((n + 1) >> 1,a,b);
    ntt_pre(n << 1);
    for (int i = 0;i < n;i++)
        c[i] = a[i];
    for (int i = n;i < maxn;i++)
        c[i] = 0;
    ntt(c,1);
    ntt(b,1);
    for (int i = 0;i < maxn;i++)
        b[i] = ((2ll * b[i] % P - 1ll * c[i] * b[i] % P * b[i] % P) % P + P) % P;
    ntt(b,2);
    int inv = mypow(maxn,P - 2);
    for (int i = 0;i < n;i++)
        b[i] = 1ll * b[i] * inv % P;
    for (int i = n;i < maxn;i++)
        b[i] = 0;
}
void NR(int *a,int *b,int n)
{
    for (int i = 0;i <= n;i++)
        b[n - i] = a[i];
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 0;i <= n;i++)
        scanf("%d",&F[i]);
    for (int i = 0;i <= m;i++)
        scanf("%d",&G[i]);
    maxn = 1;
    while (maxn <= (n + m) * 2)
        maxn <<= 1;
    for (int i = 1;i < maxn;i <<= 1)
    {
        int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
        w[i][1] = w[i][2] = 1;
        for (int j = 1;j < i;j++)
            w[i + j][1] = 1ll * w[i + j - 1][1] * G1 % P,w[i + j][2] = 1ll * w[i + j - 1][2] * G2 % P;
    }
    R(G,GR,m);
    INV(n - m + 2,GR,Gi);
    R(F,FR,n);   
    ntt_pre(n * 2 - m + 2);
    ntt(FR,1);    
    ntt(Gi,1);
    for (int i = 0;i < maxn;i++)
        Gi[i] = 1ll * Gi[i] * FR[i] % P;
    ntt(Gi,2);
    int inv = mypow(maxn,P - 2); 
    for (int i = 0;i < maxn;i++)
        Gi[i] = 1ll * Gi[i] * inv % P;
    NR(Gi,Q,n - m);
    for (int i = 0;i <= n - m;i++)
        printf("%d ",Q[i]);
    cout<<endl;
    for (int i = n - m + 1;i < maxn;i++)
        Q[i] = 0;
    ntt_pre(n + m);
    ntt(Q,1);
    ntt(G,1);
    ntt(F,1);
    for (int i = 0;i < maxn;i++)    
        F[i] = ((F[i] - 1ll * Q[i] * G[i] % P) % P + P) % P;
    ntt(F,2);
    inv = mypow(maxn,P - 2);
    for (int i = 0;i < m;i++)
        printf("%d ",1ll * F[i] * inv % P);
    return 0;
}
posted @ 2020-06-08 20:37  eee_hoho  阅读(168)  评论(0编辑  收藏  举报