Bostan-Mori 算法

EI 哥哥科普到 OI 界的科技……最近出现了基于这个的多项式复合/复合逆复杂度的突破,所以今天去看了一下。

这是一个用于解决多项式有理分式的单项系数问题 \([x^n]\frac{P}{Q}\) 的算法。该算法在解决常系数齐次线性递推问题时,代码明显短很多,常数相较原方法极其优越。

首先我们考虑用一次卷积将线性递推数列的生成函数写成有理分式的形式 \(\frac{P}{Q}\)(其中 \(Q\) 是该线性递推特征多项式,卷一下初项截断得到 \(P\))。

然后考虑如何解决 \([x^n]\frac{P}{Q}\)。我们将其写成 \(\frac{P(x)Q(-x)}{Q(x)Q(-x)}\),而 \(W(x)=Q(x)Q(-x)\)\(W(x)=W(-x)\),所以其满足奇数次项都为零。

所以我们将 \(P(x)Q(-x)\) 奇偶次项分开,设 \(P(x)Q(-x)=E(x^2)+xO(x^2),W(x)=Q(x)Q(-x)=T(x^2)\),那么原式就是 \(\frac{E(x^2)}{T(x^2)}+x\frac{O(x^2)}{T(x^2)}\)

发现左边只会贡献到偶数次项,右边只会贡献到奇数次项。所以我们只需要往一边递归,然后将 \(n\) 变成 \(\lfloor \frac{n}{2}\rfloor\) 继续做就行了。每一层进行常数次长度为线性递推式长度 \(k\) 的卷积,复杂度 \(O(k\log k\log n)\)

如果直接实现上述做法,每一层有两次卷积,一共 6 次长度为 \(2k\) 的 DFT。事实上我们可以做到更优的常数。

我们考虑直接维护 \(P,Q\) 的点值形式 \(\hat P,\hat Q\)。每次考虑从 \(\hat P,\hat Q\) 推出 \(\text{DFT}(P(x)Q(-x))\)\(\text{DFT}(Q(x)Q(-x))\),然后再对这两个多项式次数折半。

不妨假设 \(k\) 是 2 的整幂。

首先我们要把 \(P,Q\) 的点值形式从长度至少为 \(k\) 拓展到长度至少为 \(2k\)\(\hat P',\hat Q'\)。这一部分当然可以直接做一次长度为 \(k\) 的 IDFT,然后再做一次长度为 \(2k\) 的 DFT,不过我们可以更好。

注意到 \(\hat A'_{2i}=\hat A_{i},\hat A'_{2i+1}=\sum_{j} (\omega_{2k}^j A_j) \omega_{k}^{ij}\),所以 IDFT 一次然后再做一次长度为 \(k\) 的 DFT 即可。

倍长两个点值数组后,考虑表示出 \(\text{DFT}(Q(-x))\)。注意到对于数组 \(A'_i=(-1)^i A_i\)\(\hat A'_x=\sum_i A_i (-1)^i\omega_{2k}^{ix}=\hat A_{x\oplus k}\)。于是就得到了 \(\text{DFT}(Q(-x))\) 的值。

那么我们点乘就可以得到 \(\text{DFT}(P(x)Q(-x))\)\(\text{DFT}(Q(x)Q(-x))\)。考虑如何提取奇/偶次项系数。做变换 \(B_i=\frac{A_i\pm (-1)^i A_i}{2}\) 可以消去 \(A_i\) 中的奇/偶次项。我们已经会求 \(A'_i=(-1)^i A_i\) 的 DFT 数组了。所以我们可以直接求出 \(\text{DFT}(B)\)。然后考虑将次数折半。对于只有偶次项的多项式,直接取前一半点值就行了,对于只有奇次项的多项式,取完前一半后还要除以 \(\omega_{2k}^i\)

综上所述,我们一次迭代只需要进行 4 次长度恰好为 \(k\) 的 DFT,而 DFT 本身常数较小。所以说现在常系数齐次线性递推也可以承受 2e5 的数据范围了!

不仅如此,这个算法还可以用于解决特殊形式的有理二元分式求值。

对于 \(F(x,y)=\frac{1}{1-yf(x)}\),我们如果要求 \([x^n]F(x,y)\),可以对 \(x\) 这一维施加 Bostan-Mori 算法,这样迭代 \(t\) 次后 \(x\) 的次数不超过 \(\lfloor \frac{n}{2^t}\rfloor\)\(y\) 的次数不超过 \(2^t\),所以直接进行二元多项式卷积,复杂度就是 \(O(n\log^2 n)\) 的了。

而对于多项式复合/复合逆问题,其转置问题恰好是二元多项式有理分式求值,由于我忘了转置原理所以说没有继续研究了,可以看 alpha1022 博客

常系数齐次线性递推代码:

#include <cstdio>
#include <algorithm>
#include <numeric>
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<22,stdin)),p1==p2?EOF:*p1++)
char buf[1<<22],*p1=buf,*p2=buf;
using namespace std;
int read(){
	char c=getchar();int x=0;bool f=0;
	while(c<48||c>57) f|=(c=='-'),c=getchar();
	do x=x*10+(c^48),c=getchar();
	while(c>=48&&c<=57);
	if(f) return -x;
	return x;
}
const int P=998244353;
typedef long long ll;
typedef unsigned long long ull;
const int N=1<<20;
int cw[N|1],rev[N],ccw[N|1];
int bit,len,ilen;
int n,k;
int qp(int a,int b=P-2){
	int res=1;
	while(b){
		if(b&1) res=(ll)res*a%P;
		a=(ll)a*a%P;b>>=1;
	}
	return res;
}
void init(int x){
	bit=-1;len=1;
	while(len<=x) ++bit,len<<=1;
	int w=qp(3,(P-1)>>(bit+1));
	cw[0]=cw[len]=1;
	for(int i=1;i<len;++i){
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit);
		cw[i]=(ll)cw[i-1]*w%P;
	}
	ilen=qp(len);
}
void DFT(int *F){
	static ull f[N];
	for(int i=0;i<len;++i) f[i]=F[rev[i]];
	for(int i=1,tt=len>>1;i<len;i<<=1,tt>>=1)
		for(int j=0;j<len;j+=(i<<1))
			for(int k=j,t=0;k<(j|i);++k,t+=tt){
				ull x=f[k],y=f[k|i]*cw[t]%P;
				f[k]+=y;f[k|i]=x+(P-y);
			}
	for(int i=0;i<len;++i) F[i]=f[i]%P;
}
int A[N],B[N],T[N];
void multi(int *F){
	for(int i=0;i<len;++i) T[i]=F[i];
	DFT(T);reverse(T+1,T+len);
	for(int i=0;i<len;++i) T[i]=(ll)T[i]*ccw[i]%P*ilen%P;
	DFT(T);for(int i=len-1;~i;--i) F[i<<1]=F[i],F[i<<1|1]=T[i];
}
int main(){
	n=read();k=read();
	A[0]=1;
	for(int i=1;i<=k;++i){
		A[i]=(-read())%P;
		if(A[i]<0) A[i]+=P;
	}
	for(int i=0;i<k;++i){
		B[i]=read()%P;
		if(B[i]<0) B[i]+=P;
	}
	init(k<<1);
	DFT(A);DFT(B);
	for(int i=0;i<len;++i) B[i]=(ll)B[i]*A[i]%P;
	for(int i=0;i<=len;++i) ccw[i]=cw[i];
	DFT(B);reverse(B+1,B+len);
	for(int i=0;i<(len>>1);++i) A[i]=A[i<<1],B[i]=(ll)B[i]*ilen%P;
	for(int i=(len>>1);i<len;++i) A[i]=B[i]=0;
	init(k);
	for(int i=k;i<len;++i) B[i]=0;
	DFT(B);
	while(n){
		multi(A);multi(B);
		for(int i=0;i<(len<<1);++i) B[i]=(ll)B[i]*A[i^len]%P;
		for(int i=0;i<len;++i){
			A[i]=(ll)A[i]*A[i|len]%P;
			if(n&1) (B[i]-=B[i|len])<0&&(B[i]+=P);
			else (B[i]+=B[i|len])>=P&&(B[i]-=P);
			if(B[i]&1) B[i]+=P;
			B[i]>>=1;
			A[i|len]=B[i|len]=0;
		}
		if(n&1){
			for(int i=0;i<len;++i) B[i]=(ll)B[i]*ccw[(len<<1)-i]%P;
		}
		n>>=1;
	}
	printf("%d\n",int(accumulate(B,B+len,0ll)%P*qp(accumulate(A,A+len,0ll)%P)%P));
	return 0;
}
posted @ 2024-03-18 22:52  yyyyxh  阅读(115)  评论(1编辑  收藏  举报