[2019.3.25]多项式求逆
多项式求逆是什么
对于一个\(n\)次多项式\(F(x)\),要求一个小于等于\(n\)次的多项式\(G(x)\),满足
\(F(x)G(x)\equiv1(mod\ x^n)\)
\(mod\ x^n\)即只考虑所有多项式的前n项。
怎么做多项式求逆
显然,当\(F(x)\)次数为0,即只有常数项时,它的逆元就是常数项的逆元。
对于次数大于0的多项式我们假设我们已经递归求出\(F(x)\)在\(mod\ x^{\lceil\frac{n}{2}\rceil}\)意义下的逆\(H(x)\)。
也就是我们有
\(F(x)H(x)\equiv1(mod\ x^{\lceil\frac{n}{2}\rceil})\)
由\(F(x)G(x)\equiv1(mod\ x^n)\)易知\(F(x)G(x)\equiv1(mod\ x^{\lceil\frac{n}{2}\rceil})\)。
两式相减得
\(F(x)[H(x)-G(x)]\equiv0(mod\ x^{\lceil\frac{n}{2}\rceil})\)
我们有\(F(x)\not=0\),即\(H(x)-G(x)\equiv0(mod\ x^{\lceil\frac{n}{2}\rceil})\)
由于我们有若\(a\equiv b(mod\ p)\),则\(a^2\equiv b^2(mod\ x^2)\)
则两边平方得
\(H(x)^2-2G(x)H(x)+G(x)^2\equiv0(mod\ x^{2\times\lceil\frac{n}{2}\rceil})\)
.因为\(2\times\lceil\frac{x}{2}\rceil\ge n\)所以\(H(x)^2-2G(x)H(x)+G(x)^2\equiv0(mod\ x^n)\)
两边同乘\(F(x)\),由于\(F(x)G(x)\equiv1(mod\ x^n)\)
\(F(x)H(x)^2-2H(x)+G(x)\equiv0(mod\ x^n)\)
\(G(x)\equiv 2H(x)-F(x)H(x)^2(mod\ x^n)\)
\(G(x)\equiv H(x)[2-F(x)H(x)](mod\ x^n)\)
我们可以NTT实现多项式乘法,时间复杂度\(O(n\log^2n)\)。
code:
#include<bits/stdc++.h>
#define ci const int&
#define VAL(p,n,i) (i<n?p[i]:0)
using namespace std;
const int mod=998244353;
const int g=3;
int cpy[600010];
int POW(int x,int y){
int tot=1;
while(y)y&1?tot=1ll*tot*x%mod:0,x=1ll*x*x%mod,y>>=1;
return tot;
}
void NTT(vector<int>&f,ci l,ci len,ci op){
if(len&1)return;
for(int i=l;i<l+len;++i)cpy[i]=f[i];
int nw=l-1,ln=len>>1;
for(int i=l;i<l+ln;++i)f[i]=cpy[++nw],f[i+ln]=cpy[++nw];
NTT(f,l,ln,op),NTT(f,l+ln,ln,op);
int rt=POW(g,(mod-1)/len),t;
op?rt=POW(rt,mod-2):0;
nw=1;
for(int i=l;i<l+len;++i)cpy[i]=f[i];
for(int i=l;i<l+ln;++i,nw=1ll*nw*rt%mod)t=1ll*nw*cpy[i+ln]%mod,f[i]=(cpy[i]+t)%mod,f[i+ln]=(cpy[i]-t+mod)%mod;
}
vector<int>F;
vector<int>T;
vector<int>tmp;
vector<int>a;
vector<int>b;
vector<int>c;
int ts,sz,tg,inv;
void print(const vector<int>&x){
for(int i=0;i<x.size();++i)printf("%d ",x[i]);
}
vector<int>calc(const vector<int>&x,const vector<int>&y){//2y-x*y^2
ts=x.size()+y.size()+y.size()-2,sz=1,a.clear(),b.clear(),c.clear();
while(sz<ts)sz<<=1;
for(int i=0;i<x.size();++i)a.push_back(x[i]);
for(int i=0;i<y.size();++i)b.push_back(y[i]);
a.resize(sz),b.resize(sz),NTT(a,0,sz,0),NTT(b,0,sz,0);
for(int i=0;i<sz;++i)c.push_back((2-1ll*a[i]*b[i]%mod+mod)%mod*b[i]%mod);
NTT(c,0,sz,1),inv=POW(sz,mod-2);
for(int i=0;i<sz;++i)c[i]=1ll*c[i]*inv%mod;
return c;
}
int n,v;
vector<int>INV(const vector<int>&x){
if(x.size()==1)return T.resize(1),T[0]=POW(x[0],mod-2),T;
vector<int>G=x;
G.resize((x.size()+1)>>1),G=calc(x,INV(G)),G.resize(x.size());
return G;
}
int main(){
scanf("%d",&n);
for(int i=0;i<n;++i)scanf("%d",&v),F.push_back(v);
print(INV(F));
return 0;
}