多项式求逆

多项式求逆

题意

给出一个多项式 $ G( x ) $ 求一个多项式 $ F( x ) $ 满足 $ F( x ) * G( x ) = 1 ( mod x^n )$,系数对998244353取模。

解法

假设现在我们已经求出了 \(G( x )\) 在膜 $ x ^ { [ \frac {n} {2} ] }$ 下的逆元多项式 $ F'\( 那么我们有 \) G * F' = 1 ( mod x ^ { [ \frac {n} {2} ] } )$ $ G * F = 1 ( mod x ^ n ) $
∴ $ ( F' - F )^2 = 0 ( mod x^n )\( 拆开则有: \) F'^2 - 2F'F + F^2 = 0 ( mod x^n ) $
左右同乘 $ G $ 有:
$ GF'^2 - 2F' + GF = 0 ( mod x^n ) \( 移项有: \) F = 2F' - GF'^2 ( mod x^n )$
然后我们就有递推式了。同时我们可以知道,当 $ n = 1$ 时,$ F[0] = inv( G[0] ) $,所以一个多项式有逆元的充要条件为他的常数项有逆元。
代码如下:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
#include <queue>
#define INF 2139062143
#define MAX 0x7ffffffffffffff
#define del(a,b) memset(a,b,sizeof(a))
#define Rint register int
using namespace std;
typedef long long ll;
template<typename T>
inline void read(T&x)
{
    x=0;T k=1;char c=getchar();
    while(!isdigit(c)){if(c=='-')k=-1;c=getchar();}
    while(isdigit(c)){x=x*10+c-'0';c=getchar();}x*=k;
}
const int maxn=(1e5+5)*3;
const int mod=998244353;
const int g=3;
int mul(int x,int y) {return 1ll*x*y%mod;}
int add(int x,int y) {return (x+y)%mod;}
int pul(int x,int y) {return (x-y+mod)%mod;}
int poww(int a,int b){
    int ans=1;
    while(b){
        if(b&1) ans=mul(ans,a);
        a=mul(a,a);
        b>>=1;
    }
    return ans;
}
int inv(int x) {return poww(x,mod-2);}
void ntt(int n,int f,int *a){
    for(int i=0,j=0;i<n;i++){
        if(i<j) swap(a[i],a[j]);
        for(int l=n>>1;(j^=l)<l;l>>=1);
    }
    for(int i=1;i<n;i<<=1){
        int gn=poww(g,(mod-1)/(i<<1));
        if(f==-1) gn=inv(gn);
        for(int j=0;j<n;j+=(i<<1)){
            int g=1;
            for(int k=0;k<i;k++,g=mul(g,gn)){
                int x=a[j+k],y=mul(g,a[i+j+k]);
                a[j+k]=add(x,y);
                a[i+j+k]=pul(x,y);
            }
        }
    }
    if(f==-1){
        int ni=inv(n);
        for(int i=0;i<n;i++) a[i]=mul(a[i],ni);
    }
}
void inv_p(int deg,int *a,int *b,int *temp){
    if(deg==1){
        b[0]=inv(a[0]);
        return;
    }
    inv_p((deg+1)>>1,a,b,temp);
    int p=1;
    for(;p<=deg*2;p<<=1);
    for(int i=0;i<deg;i++) temp[i]=a[i];
    fill(temp+deg,temp+p,0);
    ntt(p,1,temp);ntt(p,1,b);
    for(int i=0;i<p;i++) b[i]=pul( mul( 2 , b[i] ) , mul( b[i] , mul( b[i] , temp[i] )) );
    ntt(p,-1,b);
    fill(b+deg,b+p,0);
}
int a[maxn],b[maxn],c[maxn];
int main()
{
    int n;
    read(n);
    for(int i=0;i<n;i++) read(a[i]);
    inv_p(n,a,b,c);
    for(int i=0;i<n;i++) printf("%d ",b[i]);
    return 0;
}
posted @ 2018-08-20 22:09  Mr_asd  阅读(261)  评论(0编辑  收藏  举报