总结:多项式的运算

多项式的小总结

前置芝士🧀:

多项式的各种运算

这些运算都是在模意义下进行的运算,但多项式的取模运算与整数的取模运算有些不同。

多项式对 \(x^n\) 取模的意思是舍弃 \(x^n\) 以及更高次的部分。

多项式求逆

  • 对于一个多项式 \(A(x)\) ,如果存在 \(B(x)\) 使得

\[A(x)B(x)\equiv 1\pmod {x^n} \]

  • 那么称 \(B(x)\)\(A(x)\)\(mod\: x^n\) 意义下的逆元 \((inverse\:element)\),记作 \(A^{-1}(x)\)
  • 取模意义下,没有模数的逆元是没有意义的,因为不同的模数对应不一样的逆元。

推导

  • 考虑用倍增法求解。

  • 假如我们现在已经求出了 \(A(x)\)\(mod\:x^{\frac{n}{2}}\) 意义下的逆元 \(B_0(x)\) ,即

    \[A(x)B_0(x)\equiv 1\pmod {x^{\frac{n}{2}}} \]

  • \[\because A(x)B(x)\equiv 1\pmod {x^{\frac{x}{2}}} \]

  • 两式相减并消去 \(A(x)\)

    \[\therefore B(x)-B_0(x)\equiv 0\pmod {x^{\frac{n}{2}}} \]

  • 再同时平方

    \[B^2(x)-2B(x)B_0(x)+B_0^2(x)\equiv 0\pmod {x^n}B^2(x)-2B(x)B_0(x)+B_0^2(x)\equiv 0\pmod {x^n} \]

  • 乘上 \(A(x)\),即可消去 \(B(x)\)

    \[B(x)-2B_0(x)+A(x)B_0^2\equiv 0\pmod {x^n} \]

  • 所以得到递推式

    \[B(x)=B_0(x)(2-A(x)B_0(x))\pmod {x^n} \]

  • 边界:当 \(n=1\) 时,\(B_0(x)\) 即为 \(A(x)\) 常数项的逆元。

  • 然后就可以在 \(O(nlogn)\) 的时间复杂度内求逆啦

代码

递归版:

#include <iostream>			//递归
#include <cstdio>

using namespace std;

const int maxn=4e5+10,mod=998244353,g=3,gn=332748118;

int p=1,bit,inver;
int f[maxn],h[maxn],c[maxn],rev[maxn];

inline int power(long long a,int x) {
	long long ans=1;
	while(x) {
		if(x&1) ans=(ans*a)%mod;
		a=(a*a)%mod;
		x>>=1;
	}
	return ans;
}

inline void ntt(int *t,int len,int inv) {
	for(int i=0;i<len;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<len;mid<<=1) {
		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
		int d=mid<<1;
		for(int l=0;l<len;l+=d) {
			int now=1;
			for(int i=0;i<mid;i++) {
				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
				t[l+i]=(x+y)%mod;
				t[l+mid+i]=(x-y+mod)%mod;
				now=(long long)now*unit%mod;
			}
		}
	}
	if(inv==-1) for(int i=0;i<len;i++) {
		t[i]=(long long)t[i]*inver%mod;
	}
}

inline void solve(int deg) {
	if(deg==1) {
		h[0]=power(f[0],mod-2);
		return ;
	}
	solve((deg+1)>>1);
	while(p<(deg<<1)) {p<<=1;bit++;}
	inver=power(p,mod-2);
	for(int i=1;i<p;i++) {
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	}
	for(int i=0;i<deg;i++) c[i]=f[i];
	ntt(c,p,1),ntt(h,p,1);
	for(int i=0;i<p;i++) 
		h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
	ntt(h,p,-1);
	for(int i=deg;i<p;i++) h[i]=0;			//必须要归零
}

int main() {
	int n=read();
	for(int i=0;i<n;i++) f[i]=read();
	solve(n);
	for(int i=0;i<n;i++) printf("%d ",h[i]);
	putchar('\n');
	return 0;
}

// by pycr

递推版:

#include <iostream>			//递推
#include <cstdio>

using namespace std;

const int maxn=4e5+10,mod=998244353,g=3,gn=332748118;

int p=1,bit,inver;
int f[maxn],h[maxn],t[maxn],rev[maxn];

inline int power(long long a,int x) {
	long long ans=1;
	while(x) {
		if(x&1) ans=(ans*a)%mod;
		a=(a*a)%mod;
		x>>=1;
	}
	return ans;
}

inline void ntt(int *t,int len,int inv) {
	for(int i=0;i<len;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<len;mid<<=1) {
		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
		int d=mid<<1;
		for(int l=0;l<len;l+=d) {
			int now=1;
			for(int i=0;i<mid;i++) {
				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
				t[l+i]=(x+y)%mod;
				t[l+mid+i]=(x-y+mod)%mod;
				now=(long long)now*unit%mod;
			}
		}
	}
	if(inv==-1) for(int i=0;i<len;i++) {
		t[i]=(long long)t[i]*inver%mod;
	}
}

signed main() {
	int n=read();
	for(int i=0;i<n;i++) f[i]=read();
	h[0]=power(f[0],mod-2);
	for(int i=2;i<=(n<<1)-2;i<<=1) {
		while(p<=2*i-3) {p<<=1,bit++;}
		for(int j=1;j<p;j++) {
			rev[j]=(rev[j>>1]>>1)|((j&1)<<(bit-1));
		}
		inver=power(p,mod-2);
		for(int j=0;j<i;j++) t[j]=f[j];
		ntt(t,p,1),ntt(h,p,1);
		for(int j=0;j<p;j++) h[j]=h[j]*(2-(long long)t[j]*h[j]%mod+mod)%mod;
		ntt(h,p,-1);
		for(int j=i;j<p;j++) h[j]=0;
	}
	for(int i=0;i<n;i++) printf("%d ",h[i]);
	putchar('\n');
	return 0;
}

// by pycr
  • 测出来都在 \(900ms\) 左右,相差 \(1ms\)简直奇慢无比……
  • \(Tips:\) 每一次递归(递推)结束后,都需要把 \(h\) 数组清零,不然会影响答案的正确性。

多项式对数函数

  • \(B(x)\equiv \ln\:A(x)\pmod {x^n}\)

推导

  • \(\ln\) 看着太碍眼了,有没有什么能够消除 \(\ln\) 的方法?

  • 自然是有的,联系到我们之前学的微积分知识可以想到,用链规则对 \(\ln\:A(x)\) 求导可以得到 \(\frac{A'(x)}{A(x)}\) ,学过多项式的逆就很容易计算这个式子的答案了,最后对其积分就行。即:

    \[\ln\:A(x)=\int \frac{A'(x)}{A(x)}dx \]

  • \(Tips:\) 多项式常数项为 \(1\) 时才能取 \(\ln\) ,取后常数项为 \(0\)

代码

#include <iostream>
#include <cstdio>

using namespace std;

const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g;

int p=1,bit,inver;
int f[maxn],h[maxn],c[maxn],rev[maxn];

inline int power(long long a,int x) {
	long long ans=1;
	while(x) {
		if(x&1) ans=(ans*a)%mod;
		a=(a*a)%mod;
		x>>=1;
	}
	return ans;
}

inline void ntt(int *t,int len,int inv) {
	for(int i=0;i<len;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<len;mid<<=1) {
		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
		int d=mid<<1;
		for(int l=0;l<len;l+=d) {
			int now=1;
			for(int i=0;i<mid;i++) {
				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
				t[l+i]=(x+y)%mod;
				t[l+mid+i]=(x-y+mod)%mod;
				now=(long long)now*unit%mod;
			}
		}
	}
	if(inv==-1) for(int i=0;i<len;i++) {
		t[i]=(long long)t[i]*inver%mod;
	}
}

inline void getinv(int *f,int *h,int deg) {
	if(deg==1) {
		h[0]=power(f[0],mod-2);
		return ;
	}
	getinv(f,h,(deg+1)>>1);
	while(p<(deg<<1)) {p<<=1;bit++;}
	inver=power(p,mod-2);
	for(int i=1;i<p;i++) 
		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	memcpy(c,f,deg*4);
	ntt(h,p,1);ntt(c,p,1);
	for(int i=0;i<p;i++) 
		h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
	ntt(h,p,-1);
	for(int i=deg;i<p;i++) h[i]=0;
}

inline void derivative(int *t,int len) {
	for(int i=1;i<len;i++) {
		t[i-1]=(long long)i*t[i]%mod;
	}
	t[len-1]=0;
}

inline void integrate(int *t,int len) {
	for(int i=len-1;i;i--) {
		t[i]=(long long)t[i-1]*power(i,mod-2)%mod;
	}
	t[0]=0;
}

int main() {
	int n=read();
	for(int i=0;i<n;i++) f[i]=read();
	getinv(f,h,n);
	derivative(f,n);
	ntt(f,p,1),ntt(h,p,1);
	for(int i=0;i<p;i++) f[i]=(long long)f[i]*h[i]%mod;
	ntt(f,p,-1);
	integrate(f,n);
	for(int i=0;i<n;i++) printf("%d ",f[i]);
	putchar('\n');
	return 0;
}

// by pycr

牛顿迭代

??怎么乱入啊? 牛顿迭代也是多项式运算中比较重要的一部分。

  • 多项式的牛顿迭代可不是用来在实数域和复数域上近似求解方程的。

  • 其用来求解以下方程中的 \(B(x)\)

    \[G(B(x))\equiv 0\pmod {x^n} \]

  • 还是考虑倍增法:假设我们已经求出了 \(\frac{n}{2}\) 次多项式 \(B_0(x)\) 使得:

    \[G(B_0(x))\equiv 0\pmod {x^{\frac{n}{2}}} \]

  • 结合之前泰勒展开的知识,将其在 \(B_0(x)\) 处泰勒展开:

    \[\sum_{i=0}^{+\infty}\frac{G^{(i)}(B_0(x))}{i!}(B(x)-B_0(x))^i\equiv 0\pmod {x^n} \]

    因为 \(B(x)-B_0(x)\)\(x^{\frac{n}{2}}\) 次项之下的系数都为 \(0\),所以其平方或者变成更高次幂之后在 \(mod\:x^n\) 意义下都为 \(0\),所以可以直接丢弃​​。

  • 那么原式就变为

    \[G(B(x))\equiv G(B_0(x))+G'(B_0(x))(B(x)-B_0(x))\pmod {x^n} \]

    因为 \(G(B(x))\equiv 0\pmod {x^n}\),得到

    \[B(x)\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}\pmod {x^n} \]

  • 然后就可以愉快的递归(递推)啦。

考虑用牛顿迭代实现多项式求逆

  • 其实很简单

  • \(G(B(x))=\frac{1}{B(x)}-A(x)\equiv 0\pmod {x^n}\)

  • \[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{\frac{1}{B_0(x)}-A(x)}{-\frac{1}{B_0^2(x)}}&\pmod {x^n}\\ &\equiv 2\cdot B_0(x)-B_0^2A(x)&\pmod {x^n}\\ &\equiv B_0(x)(2-B_0(x)A(x))&\pmod {x^n} \end{aligned} \]

多项式指数函数

  • \(B(x)\equiv e^{A(x)}\pmod {x^n}\)

推导

  • 这个需要用到牛顿迭代。不然我之前讲迭代干嘛?

  • 考虑对两边同时取自然对数:

    \[\ln B(x)\equiv A(x)\\ \]

  • 设函数 \(G(B(x))=\ln B(x)-A(x)\),套用牛顿迭代得:

    \[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{\ln B_0(x)-A(x)}{\frac{1}{B_0(x)}}&\pmod {x^n}\\ &\equiv B_0(x)(1-\ln B_0(x)-A(x))&\pmod {x^n} \end{aligned} \]

    结合之前的多项式对数函数即可。

  • \(Tips:\) 多项式常数项为 \(0\) 时才能取 \(\exp\) ,取后常数项为 \(1\)

代码

#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g;

int p=1,bit,inver;
int f[maxn],h[maxn],c[maxn],rev[maxn];
int h_ln[maxn],c_e[maxn],f_inv[maxn];

inline int power(long long a,int x) {
	long long ans=1;
	while(x) {
		if(x&1) ans=(ans*a)%mod;
		a=(a*a)%mod;
		x>>=1;
	}
	return ans;
}

inline void ntt(int *t,int len,int inv) {
	for(int i=0;i<len;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<len;mid<<=1) {
		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
		int d=mid<<1;
		for(int l=0;l<len;l+=d) {
			int now=1;
			for(int i=0;i<mid;i++) {
				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
				t[l+i]=(x+y)%mod;
				t[l+mid+i]=(x-y+mod)%mod;
				now=(long long)now*unit%mod;
			}
		}
	}
	if(inv==-1) for(int i=0;i<len;i++) {
		t[i]=(long long)t[i]*inver%mod;
	}
}

inline void getinv(int *f,int *h,int deg) {
	if(deg==1) {
		h[0]=power(f[0],mod-2);
		return ;
	}
	getinv(f,h,(deg+1)>>1);
	while(p<(deg<<1)) {p<<=1;bit++;}
	inver=power(p,mod-2);
	for(int i=1;i<p;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	memcpy(c,f,deg*4);
	ntt(h,p,1);ntt(c,p,1);
	for(int i=0;i<p;i++) h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
	ntt(h,p,-1);
	for(int i=deg;i<p;i++) h[i]=0;
}

inline void derivative(int *src,int *t,int len) {
	for(int i=1;i<len;i++) {
		t[i-1]=(long long)i*src[i]%mod;
	}
	t[len-1]=0;
}

inline void integrate(int *t,int len) {
	for(int i=len-1;i;i--) {
		t[i]=(long long)t[i-1]*power(i,mod-2)%mod;
	}
	t[0]=0;
}

inline void getln(int *src,int *f,int len) {
	p=1,bit=0;
	memset(f_inv,0,sizeof(f_inv));			//必须要清零,因为ntt会算上deg之后的系数,就会超出p的范围
	memset(c,0,sizeof(c));
	getinv(src,f_inv,len);
	derivative(src,f,len);
	ntt(f,p,1),ntt(f_inv,p,1);
	for(int i=0;i<p;i++) f[i]=(long long)f[i]*f_inv[i]%mod;
	ntt(f,p,-1);
	integrate(f,len);
}

inline void getexp(int *f,int *h,int deg) {
	if(deg==1) {
		h[0]=1;
		return ;
	}
	getexp(f,h,(deg+1)>>1);
	memset(h_ln,0,sizeof(h_ln));			//清零,避免爆ntt
	getln(h,h_ln,deg);
	memcpy(c_e,f,deg*4);
	ntt(c_e,p,1),ntt(h,p,1);ntt(h_ln,p,1);
	for(int i=0;i<p;i++) 
		h[i]=h[i]*(1ll-h_ln[i]+c_e[i]+mod)%mod;
	ntt(h,p,-1);
	for(int i=deg;i<p;i++) h[i]=0;
}

int main() {
	int n=read();
	for(int i=0;i<n;i++) f[i]=read();
	getexp(f,h,n);
	for(int i=0;i<n;i++) printf("%d ",h[i]);
	putchar('\n');
	return 0;
}

// by pycr
  • \(Important:\) 为什么代码中会有三个 memset 呢?我在 \(FFT\&NTT\) 的总结中也有提及,因为如果在运算的时候后面的系数不为 \(0\) 的话,乘出来的实际结果可能就会大于所预估的长度 \(p\)。实际上后面有没有系数在模意义下是不会影响结果的,错误的真正原因是因为把 \(NTT\) 乘爆了。后面的系数不会影响结果的前提是 \(NTT\) 能够得到正确的多项式。简而言之:如果原本乘出来的结果的最高次项为 \(x^{n-1}\),那么就一定至少要有 \(n\) 个点,而后面的系数则有可能导致实际的多项式会有更高次项,超出我们预估的点数。

多项式开根

  • \(B^2(x)\equiv A(x)\pmod {x^n}\)

推导

  • 仍然是牛顿迭代。

  • \(G(B(x))=B^2(x)-A(x)\),则

    \[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{B_0^2(x)-A(x)}{2B_0(x)}&\pmod {x^n}\\ &\equiv \frac{B_0^2(x)+A(x)}{2B_0(x)}&\pmod {x^n} \end{aligned} \]

    结合多项式求逆元得解。

代码

#include <iostream>
#include <cstdio>
#include <cstring>

namespace IO {
	const int N=1<<20;
	char buf[N],*l=buf,*r=buf;
	inline char gc() {
		if(l==r) r=(l=buf)+fread(buf,1,N,stdin);
		return l==r ? EOF : *(l++);
	}
	inline int read() {
		int x=0,s=1;
		char ch=gc();
		while(!isdigit(ch)) {if(ch=='-') s=-1;ch=gc();}
		while(isdigit(ch)) {x=x*10+(ch^48);ch=gc();}
		return x*s;
	}
}

using namespace std;
using IO::read;

const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g,inv_2=(mod+1)/2;

int p=1,bit;
int f[maxn],h[maxn],c[maxn],rev[maxn];
int h_inv[maxn],c_r[maxn];

inline int power(long long a,int x) {
	long long ans=1;
	while(x) {
		if(x&1) ans=(ans*a)%mod;
		a=(a*a)%mod;
		x>>=1;
	}
	return ans;
}

inline void ntt(int *t,int inv) {
	for(int i=0;i<p;i++) {
		if(i<rev[i]) swap(t[i],t[rev[i]]);
	}
	for(int mid=1;mid<p;mid<<=1) {
		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
		int d=mid<<1;
		for(int l=0;l<p;l+=d) {
			int now=1;
			for(int i=0;i<mid;i++) {
				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
				t[l+i]=(x+y)%mod;
				t[l+mid+i]=(x-y+mod)%mod;
				now=(long long)now*unit%mod;
			}
		}
	}
	if(inv==-1) {
		int inver=power(p,mod-2);
		for(int i=0;i<p;i++) {
			t[i]=(long long)t[i]*inver%mod;
		}
	}
}

inline void getinv(int *f,int *h,int deg) {
	if(deg==1) {
		h[0]=power(f[0],mod-2);
		return ;
	}
	getinv(f,h,(deg+1)>>1);
	while(p<(deg<<1)) {p<<=1;bit++;}
	for(int i=1;i<p;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
	memcpy(c,f,deg*4);
	ntt(h,1),ntt(c,1);
	for(int i=0;i<p;i++) h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
	ntt(h,-1);
	for(int i=deg;i<p;i++) h[i]=0;
}

inline void getroot(int *f,int *h,int deg) {
	if(deg==1) {
		h[0]=1;
		return ;
	}
	getroot(f,h,(deg+1)>>1);
	p=1,bit=0;
	memset(h_inv,0,sizeof(h_inv));			//清零
	memset(c,0,sizeof(c));
	getinv(h,h_inv,deg);
	memcpy(c_r,f,deg*4);
	ntt(h,1),ntt(c_r,1),ntt(h_inv,1);
	for(int i=0;i<p;i++) h[i]=((long long)h[i]*h[i]%mod+c_r[i])*inv_2%mod*h_inv[i]%mod;
	ntt(h,-1);
	for(int i=deg;i<p;i++) h[i]=0;
}

int main() {
//#ifndef ONLINE_JUDGE
#ifdef LOCAL
	freopen("c.in","r",stdin);
	//freopen("c.out","w",stdout);
#endif
	//ios::sync_with_stdio(false);
	//cin.tie(0);cout.tie(0);
	int n=read();
	for(int i=0;i<n;i++) f[i]=read();
	getroot(f,h,n);
	for(int i=0;i<n;i++) printf("%d ",h[i]);
	putchar('\n');
	return 0;
}

// by pycr
  • \(Tips:\) 和之前一样,每次都需要清零。

——2021年2月8日

posted @ 2021-02-09 09:55  pycr  阅读(1053)  评论(0编辑  收藏  举报