多项式求逆

前置知识

NTT

多项式取模: 多项式模 \(x^{n}\) 表示取多项式的前 \(n\)

多项式求逆

给定

\[f(x)=a_{0}+a_{1}x^{1}+a_{2}x^{2}+···+a_{n}x^{n} \]

求出

\[g(x)=b_{0}+b_{1}x^{1}+b_{2}x^{2}+···+b_{k}x^{k}\ \ \ (k\le n) \]

使得

\[f(x)g(x)\equiv 1\ \ \ (mod\ x^{n}) \]

做法

\(n=1\) 时,显然有 \(b_{0}=inv(a_{0})\).

\(n>1\) 时,假设我们已经知道 \(mod\ x^{\frac{n}{2}}\) 意义下的逆

\[f^{-1}(x)=g'(x) \]

那么有:

\[f(x)g'(x)\equiv 1\ \ \ (mod\ x^{\frac{n}{2}}) \]

且我们知道 :

\[f(x)g(x)\equiv 1\ \ \ (mod\ x^{\frac{n}{2}}) \]

两式相减,能得到 :

\[f(x)(g(x)-g'(x))\equiv 0\ \ \ (mod\ x^{\frac{n}{2}}) \]

可以同时除去 \(f(x)\) ,那么有:

\[g(x)-g'(x)\equiv 0\ \ \ (mod\ x^{\frac{n}{2}}) \]

两边平方,则:

\[g^2(x)-2g(x)g'(x)+g'^2(x)\equiv 0\ \ \ (mod\ x^n) \]

两边同乘 \(f(x)\),消掉 \(g(x)\)

\[g(x)-2g'(x)+f(x)g'^2(x)\equiv 0\ \ \ (mod\ x^n) \]

移一下项,就很好算了:

\[g(x)\equiv 2g'(x)-f(x)g'^2(x)\ \ \ (mod\ x^n) \]

\[g(x)\equiv g'(x)*(2-f(x)g'(x))\ \ \ (mod\ x^n) \]

我们可以愉快地通过迭代的方式求出 \(g(x)\).

板子题

CODE

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define N 300006
#define LL long long 
using namespace std;

const int G=3;
const int IG=332748118;
const int mod=998244353;

int n,bit,tot,rev[N];
int f[N],g[N],a[N];

inline int mul(int a,int b){return (LL)a*b%mod;}
inline int sub(int a,int b){return a-b<0?a-b+mod:a-b;};
inline int add(int a,int b){return a+b>mod?a+b-mod:a+b;}

inline int qr()
{
	char a=0;int w=1,x=0;
	while(a<'0'||a>'9'){if(a=='-')w=-1;a=getchar();}
	while(a<='9'&&a>='0'){x=(x<<3)+(x<<1)+(a^48);a=getchar();}
	return x*w;
}

inline int poww(int a,int x)
{
	int ans=1;
	while(x)
	{
		if(x&1)
			ans=mul(ans,a);
		a=mul(a,a);
		x>>=1;
	}
	return ans;
}

int exgcd(int a,int b,int &x,int &y)
{
	if(!b)
	{
		x=1,y=0;
		return a;
	}
	int d=exgcd(b,a%b,y,x);
	y-=a/b*x;
	return d;
}

inline int inv(int a)
{
	int x,y;
	exgcd(a,mod,x,y);
	return (x%mod+mod)%mod;
}

inline void init_rev(int len)
{
	bit=0;
	while((1<<bit)<(len<<1))
		bit++;
	tot=1<<bit;
	for(register int i=0;i<tot;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}

inline void NTT(int *a,int tot,int opt)
{
	for(register int i=0;i<tot;i++)
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	for(register int mid=1;mid<tot;mid<<=1)
	{
		int w1=poww(opt==1?G:IG,(mod-1)/(mid<<1));
		for(register int i=0;i<tot;i+=(mid<<1))
		{
			int wk=1;
			for(register int j=0;j<mid;j++)
			{
				int x=a[i+j];
				int y=mul(wk,a[i+j+mid]);
				a[i+j]=add(x,y);
				a[i+j+mid]=sub(x,y);
				wk=mul(wk,w1);
			}
		}
	}
	if(opt==-1)
	{
		int inv_=inv(tot);
		for(register int i=0;i<tot;i++)
			a[i]=mul(a[i],inv_);
	}
}

void poly_inv(int len,int *a,int *g)
{
	if(len==1)
	{
		g[0]=inv(a[0]);
		return ;
	}
	poly_inv((len+1)>>1,a,g);
	init_rev(len);
	for(register int i=0;i<len;i++)
		f[i]=a[i];
	for(register int i=len;i<tot;i++)
		f[i]=0;
	NTT(f,tot,1);
	NTT(g,tot,1);
	for(register int i=0;i<tot;i++)
		g[i]=mul(g[i],sub(2,mul(g[i],f[i])));
	NTT(g,tot,-1);
	for(register int i=len;i<tot;i++)
		g[i]=0;
}

int main()
{
	n=qr();
	for(register int i=0;i<n;i++)
		a[i]=qr();
	poly_inv(n,a,g);
	for(register int i=0;i<n;i++)
		printf("%d ",g[i]);
	printf("\n");
	return 0;
}
posted @ 2021-02-18 16:21  江北南风  阅读(154)  评论(0编辑  收藏  举报