多项式求逆元详解+模板 【洛谷P4238】多项式求逆

概述

多项式求逆元是一个非常重要的知识点,许多多项式操作都需要用到该算法,包括多项式取模,除法,开跟,求ln,求exp,快速幂。用快速傅里叶变换和倍增法可以在$O(n log n)$的时间复杂度下求出一个$n$次多项式的逆元。

 

前置技能

快速数论变换(NTT),求一个数$x$在模$p$意义下的乘法逆元。

 

多项式的逆元

给定一个多项式$A(x)$,其次数为$deg_A$,若存在一个多项式$B(x)$,使其满足$deg_B≤deg_A$,且$A(x)\times B(x) \equiv 1 (mod\ x^n)$,则$B(x)$即为$A(x)$在模$x^n$意义下的的乘法逆元。

 

求多项式的逆元

我们不妨假设,$n=2^k,k∈N$。

若$n=1$,则$A(x)\times B(x) \equiv a_0\times b_0 \equiv 1 (mod\ x^1)$。其中$a_0$,$b_0$表示多项式$A$和多项式$B$的常数项。

若需要求出$b_0$,直接用费马小定理求出$a_0$的乘法逆元即可。

当$n>1$时:

我们假设在模$x^{\frac{n}{2}}$的意义下$A(x)$的逆元$B'(x)$我们已经求得。

依据定义,则有

$A(x)B'(x)\equiv 1 (mod\ x^{\frac{n}{2}})$          $(1)$

对$(1)$式进行移项得

$A(x)B'(x)-1\equiv 0 (mod\ x^{\frac{n}{2}})$          $(2)$

然后对$(2)$式等号两边平方,得

$A^2(x)B'^2(x)-2A(x)B'(x)+1\equiv 0(mod\ x^{n})$          $(3)$

将常数项移动到等式右侧,得

$A^2(x)B'^2(x)-2A(x)B'(x)\equiv -1(mod\ x^{n})$          $(4)$

将等式两边去相反数,得

$2A(x)B'(x)-A^2(x)B'^2(x)\equiv 1(mod\ x^{n})$          $(5)$

下面考虑回我们需要求的多项式$B(x)$,依据定义,其满足

$A(x)B(x)\equiv 1(mod\ x^{n})          $(6)$

将$(5)-(6)$并移项,得

$A(x)B(x)\equiv 2A(x)B'(x)-A^2(x)B'^2(x)(mod\ x^{n})$          $(7)$

等式两边约去$A(x)$,得

$B(x)\equiv 2B'(x)-A(x)B'^2(x)(mod\ x^{n})$          $(8)$

 

 

显然,我们可以用上述式子求出$B(x)$。

这一步的计算我们可以使用$NTT$,时间复杂度为$O(n log n)$。

我们可以通过递归的方法,求解出$B(x)$。

时间复杂度$T(n)=T(\dfrac{n}{2})+O(n log n)=O(n log n)$。

 

洛谷上有一道题目就叫做多项式求逆元(点这里),可以先做下那一题。

模板如下:

 

 1 #include<bits/stdc++.h>
 2 #define M (1<<19)
 3 #define L long long
 4 #define MOD 998244353
 5 #define G 3
 6 using namespace std;
 7 
 8 L pow_mod(L x,L k){
 9     L ans=1;
10     while(k){
11         if(k&1) ans=ans*x%MOD;
12         x=x*x%MOD; k>>=1;
13     }
14     return ans;
15 }
16 
17 void change(L a[],int n){
18     for(int i=0,j=0;i<n-1;i++){
19         if(i<j) swap(a[i],a[j]);
20         int k=n>>1;
21         while(j>=k) j-=k,k>>=1;
22         j+=k;
23     }
24 }
25 void NTT(L a[],int n,int on){
26     change(a,n);
27     for(int h=2;h<=n;h<<=1){
28         L wn=pow_mod(G,(MOD-1)/h);
29         for(int j=0;j<n;j+=h){
30             L w=1;
31             for(int k=j;k<j+(h>>1);k++){
32                 L u=a[k],t=w*a[k+(h>>1)]%MOD;
33                 a[k]=(u+t)%MOD;
34                 a[k+(h>>1)]=(u-t+MOD)%MOD;
35                 w=w*wn%MOD;
36             }
37         }
38     }
39     if(on==-1){
40         L inv=pow_mod(n,MOD-2);
41         for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD;
42         reverse(a+1,a+n);
43     }
44 }
45 
46 void getinv(L a[],L b[],int n){
47     if(n==1){b[0]=pow_mod(a[0],MOD-2); return;}
48     static L c[M],d[M];
49     memset(c,0,n<<4); memset(d,0,n<<4);
50     getinv(a,c,n>>1);
51     for(int i=0;i<n;i++) d[i]=a[i];
52     NTT(d,n<<1,1); NTT(c,n<<1,1);
53     for(int i=0;i<(n<<1);i++) b[i]=(2*c[i]-d[i]*c[i]%MOD*c[i]%MOD+MOD)%MOD;
54     NTT(b,n<<1,-1);
55     for(int i=0;i<n;i++) b[n+i]=0;
56 }
57 L a[M]={0},b[M]={0};
58 int main(){
59     int n,N; scanf("%d",&n);
60     for(int i=0;i<=n;i++) scanf("%lld",a+i);
61     for(N=1;N<=n;N<<=1);
62     getinv(a,b,N);
63     for(int i=0;i<=n;i++) printf("%lld ",b[i]);
64 }

 

posted @ 2018-05-29 20:33  AlphaInf  阅读(3278)  评论(1编辑  收藏  举报