算法记录 001:用多项式的逆优化dp总结

用多项式的逆优化dp总结

考虑一个经典的模型:

\(dp_i=\sum_{j=1}^i dp_{i-j}\times f_j\)

\(dp_n\) 这种问题表面看上去需要\(n^2\),但是事实上可以更优。

我们建立一个长度无穷大的多项式:\(F(x)=\sum _{i=0}^\infty f_{i}\times x^i\),特殊的:\(f_{0}=f_{k+n}=0,k\in \mathbb{N}^*\)

然后可以发现\(dp_n=\sum _{i=0}^{\infty}[x^n](F^i(x))\)

由于\(\sum _{i=0}^{\infty} x^i=\frac{1}{1-x}\)

所以:\(dp_{n}=[x^n]\frac{1}{1-F(x)}\)

然后\(dp_n\)就是多项式\(G(x)=1-F(x)\)的逆的第n项系数。

如何求多项式的逆元?

题解 P4238 【模板】多项式求逆- Great_Influence 的博客 - 洛谷博客 (luogu.com.cn)

my code:

const int MOD=998244353;
const int g=3;
int len;
int rev[1<<19];
void butterfly(vector<int> & v){
	rep(i,len){
		rev[i]=rev[i>>1]>>1;
		if(i&1) rev[i]|=len>>1; 
	}
	rep(i,len) if(rev[i]>i) swap(v[i],v[rev[i]]);
}
LL quick(LL A,LL B){
	if(B==0) return 1;
	LL  tmp=quick(A,B>>1);
	tmp*=tmp;
	tmp%=MOD;
	if(B&1)
		tmp*=A,tmp%=MOD;
 	return tmp;
}
int inv(int x){
	return quick(x,MOD-2);
}
vector<int> ntt(vector<int> v,int ty){
	butterfly(v);
	vector<int> nex;
	for(int l=2;l<=len;l<<=1){
		nex.clear();
		nex.resize(len);
		int step=quick(g,(MOD-1)/l);
		if(ty==-1) step=inv(step); 
		for(int j=0;j<len;j+=l){
			int now=1;
			for(int k=0;k<l/2;++k){
				int A,B;
				A=v[j+k];
				B=v[j+l/2+k];
				B=1ll*now*B%MOD;
				nex[j+k]=(A+B)%MOD;
				nex[j+k+l/2]=(A-B+MOD)%MOD;
				now=1ll*now*step%MOD;
			}
		}
		v=nex;
	}
	return v;
}
int get(int x){
	int y=1;
	while(y<x) y<<=1;
	return y;
}
vector<int> mul(vector<int> A,vector<int> B){
	len=get(A.size()+B.size());
	A.resize(len);
	B.resize(len);
	A=ntt(A,1);
	B=ntt(B,1);
	rep(i,len) A[i]=1ll*A[i]*B[i]%MOD;
	A=ntt(A,-1);
	int inve=quick(len,MOD-2);
	rep(i,len) A[i]=1ll*A[i]*inve%MOD;
	return A;
}
vector<int> inverse(vector<int> A,int n){
	//计算% x^n 的逆元
	vector<int> ret(n);
	if(n==1){
		ret[0]=quick(A[0],MOD-2);
	}
	else{
		vector<int> oth=inverse(A,(n+1)>>1);
		ret=oth;
		ret.resize(n);
		rep(i,n) ret[i]=2ll*ret[i]%MOD;
		oth=mul(oth,oth);
		oth.resize(n);
		oth=mul(oth,A);
		oth.resize(n);
		rep(i,n) ret[i]=(ret[i]-oth[i]+MOD)%MOD;
	}
	return ret;
}
int main(){
//	freopen("test.in","r",stdin);
//	freopen("test.out","w",stdout);
	vector<int> poly;
	int n;
	scanf("%d",&n);
	poly.resize(n);
	rep(i,n) scanf("%d",&poly[i]);
	vector<int> inver=inverse(poly,n);
	rep(i,n) printf("%d ",inver[i]);
	cout<<endl;
	return 0;
}
posted @ 2020-12-31 23:42  WWW~~~  阅读(120)  评论(0编辑  收藏  举报