P4238 【模板】多项式求逆

思路

多项式求逆就是对于一个多项式\(A(x)\),求一个多项式\(B(x)\),使得\(A(x)B(x) \equiv 1 \ (mod x^n)\)

假设现在多项式只有一项,显然\(B(x)\)的第0项(常数项)就是\(A(x)\)的第0项(常数项)的逆元(所以\(A(x)\)有没有逆元取决于\(A(x)\)的常数项有没有逆元)

那我们可以利用递归的方法,

现在要求

\[A(x)B(x) \equiv 1 (mod\ x^n) \]

假设有多项式\(B'(x)\),满足

\[A(x)B'(x)\equiv 1 (mod\ x^{\lfloor\frac{n}{2}\rfloor}) \]

则要求的\(B(x)\),必定也满足

\[A(x)B(x) \equiv 1 (mod\ x^{\lfloor\frac{n}{2}\rfloor}) \]

所以有

\[A(x)(B(x)-B'(x)) \equiv 0 (mod\ x^{\lfloor\frac{n}{2}\rfloor})\\B(x)-B'(x)\equiv 0 (mod\ x^{\lfloor\frac{n}{2}\rfloor})\\ (B(x)-B'(x))^2\equiv 0 (mod\ x^{n})\\ B^2(x)-2B(x)B'(x)+B'^2(x)\equiv 0 (mod\ x^{n})\\ \]

两侧都乘\(A(x)\)

\[B(x)-2B'(x)+2A(x)B'^2(x)\equiv 0 (mod \ x^n) \]

所以

\[B(x)\equiv B'(x)(2-A(x)B'(x)) \]

递归求解即可

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
const int MOD = 998244353,G = 3, invG = 332748118;
int pow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)
            ans=(a*ans)%MOD;
        a=(a*a)%MOD;
        b>>=1;
    }
    return ans;
}
void NTT(int *a,int n,int opt){
    int lim=0;
    while((1<<lim)<n)
        lim++;
    n=(1<<lim);
    for(int i=0;i<n;i++){
        int t=0;
        for(int j=0;j<lim;j++)
            if((i>>j)&1)
                t|=(1<<(lim-j-1));
        if(t<i)
            swap(a[t],a[i]);
    }
    for(int i=2;i<=n;i<<=1){
        int len=i/2;
        int tmp=pow((opt)?G:invG,(MOD-1)/i);
        for(int j=0;j<n;j+=i){
            int mid=1;
            for(int k=j;k<j+len;k++){
                int t=a[k+len]*mid;
                a[k+len]=(a[k]-t+MOD)%MOD;
                a[k]=(a[k]+t)%MOD;
                mid=(mid*tmp)%MOD;
            }
        }
    }
    if(!opt){
        int invn=pow(n,MOD-2);
        for(int i=0;i<n;i++)
            a[i]=(a[i]*invn)%MOD;
    }
}
int c[300100],a[300100],b[300100],n;
void get_inv(int times,int *a,int *b){
    if(times==1){
        b[0]=pow(a[0],MOD-2);
        return;
    }
    get_inv((times+1)>>1,a,b);
    while((times<<1)>n)
        n<<=1;
    for(int i=0;i<times;i++){
        c[i]=a[i];
    }
    for(int i=times;i<n;i++){
        c[i]=0;
    }
    NTT(c,n,1);
    NTT(b,n,1);
    for(int i=0;i<n;i++){
        b[i]=((2-c[i]*b[i]%MOD)%MOD+MOD)%MOD*b[i]%MOD;
    }
    NTT(b,n,0);
    for(int i=times;i<n;i++){
        b[i]=0;
    }
}
signed main(){
    scanf("%lld",&n);
    int tx=n;
    for(int i=0;i<tx;i++)
        scanf("%lld",&a[i]);
    n=1;
    get_inv(tx,a,b);
    for(int i=0;i<tx;i++){
        printf("%lld ",(b[i]+MOD)%MOD);
    }
    return 0;
}
posted @ 2019-03-20 19:11  dreagonm  阅读(220)  评论(0编辑  收藏  举报