多项式乘法逆
多项式乘法逆
我们考虑如何求一个多项式 \(B(x)\) 使得其满足 \(A(x)B(x)\equiv 1 \bmod x^n\),其中 \(A\) 是一个 \(n-1\) 次多项式。
根据在我“多项式入门”博客中推导的式子,我们可以递归来做这个题。
有几个误区:
-
我们不能直接算出点值然后最后再 IDFT 来做这个题,原因是我们考虑每次做的多项式乘法都是一个新的多项式乘法,有不同的多项式长度,故不能这样做。
-
我们需要处理好边界问题,首先蝴蝶变换的大小是根据我们最终多项式的最高次数决定的,必须比最高次数大。其次要保证模的性质,不该有数的地方我们就让其为 \(0\)。
代码:
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 400010
#define M number
using namespace std;
const int INF=0x3f3f3f3f;
const int mod=998244353;
const int g=3;
const int invg=332748118;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
int tr[N];
inline int ksm(int a,int b,int mod){int res=1;while(b){if(b&1) res=1ll*a*res%mod;a=1ll*a*a%mod;b>>=1;}return res;}
inline void Gettr(int n){
for(int i=0;i<n;i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(n>>1):0);
}
inline void NTT(int *f,int len,int flag){
for(int i=0;i<len;i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
for(int p=2;p<=len;p<<=1){
int md=ksm(g,(mod-1)/p,mod),l=p>>1;
if(flag==-1) md=ksm(md,mod-2,mod);
for(int i=0;i<len;i+=p){
int buf=1;
for(int j=i;j<i+l;j++){
int tt=1ll*f[j+l]*buf%mod;
f[j+l]=((f[j]-tt)%mod+mod)%mod;
f[j]=(f[j]+tt)%mod;buf=1ll*buf*md%mod;
}
}
}
}
int n,m,c[N],a[N],b[N];
inline void GetInv(int len,int *a,int *b){
if(len==1){b[0]=ksm(a[0],mod-2,mod);return;}
GetInv((len+1)>>1,a,b);m=1;while(m<(len<<1)) m<<=1;
Gettr(m);for(int i=0;i<len;i++) c[i]=a[i];
for(int i=len;i<m;i++) c[i]=0;
// printf("m=%d\n",m);
// for(int i=0;i<m;i++) printf("%d ",c[i]);
NTT(c,m,1);NTT(b,m,1);
for(int i=0;i<m;i++) b[i]=1ll*(2-1ll*b[i]*c[i]%mod+mod)%mod*b[i]%mod;
NTT(b,m,-1);int inv=ksm(m,mod-2,mod);for(int i=0;i<m;i++) b[i]=1ll*b[i]*inv%mod;
for(int i=len;i<m;i++) b[i]=0;
// printf("len=%d\n",len);
// for(int i=0;i<n;i++) printf("%d ",b[i]);puts("");
}
int main(){
// freopen("my.in","r",stdin);
// freopen("my.out","w",stdout);
read(n);for(int i=0;i<n;i++) read(a[i]);
GetInv(n,a,b);
for(int i=0;i<n;i++) printf("%d ",b[i]);
return 0;
}