常系数齐次线性递推

常系数齐次线性递推

定义

对于一个递推式,如果 \(a_n = \displaystyle \sum_{i=1}^{k}{a_{n-i}*f_i}\) ,那么称这个 \(a\) 序列满足 \(n\) 阶常系数齐次线性递推关系

矩阵优化

如果我们已知一个满足 \(k\) 阶常系数齐次线性递推关系的序列 \(a\) ,关系式为 \(a_n = \displaystyle \sum_{i=1}^{k}{a_{n-i} * f_i}\) ,要求出 \(a_n\) 的值

可以设计出一个转移矩阵进行矩阵优化

如果初始阵为

\[A= \begin{pmatrix} a_{n-1}\\ a_{n-2}\\ \vdots\\ a_{n-k} \end{pmatrix} \]

转移阵为

\[M= \begin{pmatrix} f_1 \quad &f_2 \quad &f_3 \quad &\dots \quad &f_{k-1}\\ 1 \quad &0 \quad &0 \quad &\dots \quad &0\\ 0 \quad &1 \quad &0 \quad &\dots \quad &0\\ \vdots \quad &\vdots \quad &\vdots \quad &\ddots \quad &\vdots \\ 0 \quad &0 \quad &0 \quad &\dots &1 \end{pmatrix} \]

那么 \(M \times A\) 可以得到矩阵

\[\begin{pmatrix} a_{n}\\ a_{n-1}\\ \vdots\\ a_{n-k-1} \end{pmatrix} \]

那么我们可以设计初始矩阵为

\[A= \begin{pmatrix} a_{k-1}\\ a_{k-2}\\ \vdots\\ a_{0} \end{pmatrix} \]

此时我们可以用 \(M^n \times A\) 来得到我们需要的矩阵

特征多项式

  • 若有常数 \(\lambda\) ,向量 \(\vec{v}\) ,满足 \(\lambda \vec{v} = A \vec{v}\) ,那么我们称 \(\lambda\) 为矩阵 \(A\) 的特征值,称 \(\vec{v}\) 为矩阵的特征向量

那么我们可以得到 \((\lambda I - A) \vec{v}= 0\) ,其中 \(0\) 表示零矩阵

此时该式有解当且仅当 \(det(\lambda I - A) = 0\)

这个行列式的展开形式为一个 \(k\) 次多项式,此时,我们称这个 \(k\) 次多项式为 \(A\) 的特征多项式,该多项式的值为 \(0\) 时的方程称为 \(A\) 的特征方程

记特征多项式为 \(f(x) = det(\lambda I - A)\) ,那么可以表示为 \(f(x) = \displaystyle \prod_{i}{\lambda_i - x}\)

凯莱-哈密顿定理 (Cayley-Hamilton定理)

  • 对于 \(A\) 的特征多项式 \(f(x)\) ,有 \(f(A) = 0\)

证明

\(f(A) =\displaystyle \prod_{i}{\lambda_i I - A}\)

对于这个 \(k\) 次的特征多项式,其有 \(k\) 个解,也就是说矩阵 \(A\)\(k\) 个特征值以及 \(k\) 个线性无关的特征向量,而如果 \(f(A)\) 得到的矩阵乘上任意一个特征向量都可以得到零矩阵,那么就可以推出 \(f(A)\) 为零矩阵

首先,可以证明, \((\lambda_i I - A)(\lambda_j I - A) = (\lambda_j I - A)(\lambda_i I - A)\)

那么

\[\begin{aligned} f(A) \times \vec{v_i} &= (\displaystyle \prod_{j}{\lambda_j I - A}) \times \vec{v_i} \\ &= (\displaystyle \prod_{j \neq i}{\lambda_j I - A}) \times ((\lambda_i I - A) \times \vec{v_i}) \end{aligned} \]

由特征值与特征向量的定义式可知: \((\lambda_i I - A) \vec{v_i} = 0\)

所以 \(\forall f(A) \times \vec{v_i} =0\)

得证

常系数齐次线性递推优化

设矩阵 \(M\) 的特征多项式为 \(f(x)\)

对于我们要求的 \(M^n\) ,可以写出

\[M^n = f(M) \times g(M) + R(M) \]

\(f(M)=0\) ,那么就有 \(M^n = R(M)\)

所以,我们只需要做 \(M^n ~\% ~f(M)\) 就可以了

考虑 \(f(M)\) 怎么求

按照定义 \(f(x) = \det(x I - M)\) ,所以这里有

\[f(x)= \begin{vmatrix} x-a_1 \quad &-a_2 \quad &-a_3 \quad &\dots \quad &-a_{k-1}\quad &-a_{k}\\ -1 \quad &x \quad &0 \quad &\dots \quad &0 \quad &0\\ 0 \quad &-1 \quad &x \quad &\dots \quad &0 \quad &0\\ \vdots \quad &\vdots \quad &\vdots \quad &\ddots \quad &\vdots \quad &\vdots\\ 0 \quad &0 \quad &0 \quad &\dots &-1 \quad &x \end{vmatrix} \]

将其进行展开,有

\[\begin{aligned} f(x) &= \displaystyle (x-a_1)M_{11} - a_2 M_{12} \dotsb - a_k M_{1k}\\ &= x^k - a_1 x^{k-1} - a_2 x^{k-2} - \dotsb a_k \end{aligned} \]

处理 \(M^n ~\%~ f(M)\) 我们可以在做快速幂的时候进行实现,所以这里的实现只需要在快速幂的时候做多项式取模即可,复杂度为 \(O(k^2 \log n)\)

而这里我们做快速幂的时候还会涉及多项式乘法,那么可以进行 NTT\FFT 优化,做到 \(O(k \log k \log n)\)

那么我们这里已经快速处理出了 \(M^n\) ,之后直接和初始的矩阵 \(A\) 相乘即可求得答案

例题

首先,恰好 \(K\) 个的概率不容易处理,可以考虑将其处理为至少有 \(K\) 个的概率减去至少有 \(K-1\) 个的概率

\(f_i\) 表示在底部的一个宽为 \(i\) 的矩形,并且第 \(i\) 个位置恰好为不合法的位置

那么最终答案就是 \(\frac{f_{n+1}}{1-q}\)

这里有 \(f_n = \displaystyle \sum_{i=1}^{n}{f_{n-i+1} * g_i}\) ,这里 \(g_i\) 表示出现长度宽为 \(i\) 的矩形的概率

\(dp_{i,j}\) 表示一个宽为 \(i\) ,高位 \(j\) 的矩形

那么这里 \(g_i = \displaystyle \sum_{j=1}^{\infty}{dp_{i,j}}\)

而这个 \(dp_{i,j}\) 实际上也是可以递推的,有递推式为

\[dp_{i,j} = [i*(j-1) \leq K] (1-q) q^{j-1} \displaystyle \sum_{k=1}^{i}{(\displaystyle \sum_{q > j} dp_{k-1,q})(\displaystyle \sum_{q \geq j}{dp_{i-k,q}})} \]

表示 \(dp_{i,j}\) 可以由宽 \(k-1\) 中那些高度大于 \(j\) 的矩形的情况在和宽 \(i-k\) ,高大于等于 \(j\) 的那些矩形拼起来,再乘上当前宽度为 \(i\) 的这个地方的高度只有 \(j\) 的部分的概率

这样,这里 \(i \times (j-1) \leq K\) ,所以对 \(i,j\) 的枚举的复杂度为 \(O(K \log K)\) ,再加上枚举 \(k,q\) 的枚举,复杂度为 \(O(K^2 \log^2 K)\)

而这个式子中对 \(q\) 枚举的部分是可以后缀和优化的(并且在 \(f\) 的求解中应用),那么此时求 \(dp\) 数组的复杂度可以被优化到 \(O(k^2 \log k)\)

\(f\) 数组的求解显然满足常系数其次线性递推的形式,可以直接套用优化,那么总复杂度为 \(O(k^2 \log k)\) (完全没有必要用 FFT\NTT 优化,直接暴力做多项式取模即可)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<cstring>
#define ll long long
#define ld long double

inline ll read()
{
	ll x=0,f=1;
	char ch=getchar();
	while(!isdigit(ch))
	{
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(isdigit(ch))
	{
		x=(x<<1)+(x<<3)+ch-'0';
		ch=getchar();
	}
	return x*f;
}

const ll inf=1e18;
const ll maxn=2e3+10;
const ll mod=998244353;
ll N,K,X,Y,p,q;
ll pw[maxn];
ll dp[maxn][maxn],sum[maxn][maxn];
ll I[maxn],A[maxn],M[maxn],ret[maxn],f[maxn];
ll tmp1[maxn],tmp2[maxn];

inline ll ksm(ll a,ll b,ll p)
{
	ll ret=1;
	while(b)
	{
		if(b&1) ret=ret*a%p;
		a=a*a%p;
		b>>=1;
	}
	return ret;
}

inline ll sol(ll x)
{
	ll ans=0;
	memset(M,0,sizeof(M));
	memset(A,0,sizeof(A));
	memset(f,0,sizeof(f));
	memset(I,0,sizeof(I));
	memset(dp,0,sizeof(dp));
	memset(sum,0,sizeof(sum));
	memset(ret,0,sizeof(ret));
	for(int i=0;i<=x+2;i++) sum[0][i]=dp[0][i]=1;
	for(int j=x;j>=1;j--)
	{
		for(int i=1;i*j<=x;i++)
		{
			for(int k=1;k<=i;k++)
			{
				(dp[i][j]+=sum[k-1][j+1]*sum[i-k][j]%mod*p%mod*pw[j]%mod)%=mod;
			}
			sum[i][j]=(sum[i][j+1]+dp[i][j])%mod;
		}
	}
//	for(int j=1;j<=x;j++)
//	{
//		for(int i=1;i*j<=x;i++)
//		{
//			printf("%d %d %lld %lld\n",i,j,dp[i][j],sum[i][j]);
//		}
//	}
	x++;
	for(int i=1;i<=x;i++) I[i]=sum[i-1][1]*p%mod;
	A[0]=1;
	for(int i=1;i<=x;i++)
	{
		for(int j=0;j<i;j++)
		{
			(A[i]+=A[j]*I[i-j]%mod)%=mod;
		}
	}
	for(int i=1;i<=x;i++) f[x-i]=mod-I[i];
	f[x]=1;
//	for(int i=0;i<=x;i++) printf("%lld ",I[i]);
//	putchar(10);
//	for(int i=0;i<=x;i++) printf("%lld ",A[i]);
//	putchar(10);
//	for(int i=0;i<=x;i++) printf("%lld ",f[i]);
//	putchar(10);
	ret[0]=1;
	M[1]=1;
	ll b=N+1;
	while(b)
	{
		if(b&1)
		{
			memcpy(tmp1,ret,sizeof(ret));
			memset(ret,0,sizeof(ret));
			for(int i=0;i<=x;i++)
			{
				for(int j=0;j<=x;j++)
				{
					(ret[i+j]+=M[i]*tmp1[j])%=mod;
				}
			}
			for(int i=2*x;i>=x;i--)
			{
				for(int j=0;j<=x;j++)
				{
					(ret[i+j-x]+=mod-ret[i]*f[j]%mod)%=mod;
				}
			}
		}
		memcpy(tmp1,M,sizeof(M));
		memcpy(tmp2,M,sizeof(M));
		memset(M,0,sizeof(M));
		for(int i=0;i<=x;i++)
		{
			for(int j=0;j<=x;j++)
			{
				(M[i+j]+=tmp1[i]*tmp2[j]%mod)%=mod;
			}
		}
		for(int i=2*x;i>=x;i--)
		{
			for(int j=0;j<=x;j++)
			{
				(M[i+j-x]+=mod-M[i]*f[j])%=mod;
			}
		}
		b>>=1;
	}
	for(int i=0;i<=x;i++) (ans+=ret[i]*A[i])%=mod;
//	printf("%lld\n",ans);
	return ans*ksm(p,mod-2,mod)%mod;
}

int main(void)
{
//	freopen("1.in","r",stdin);
//	freopen("1.ans","w",stdout);
	N=read(),K=read(),X=read(),Y=read();
	q=X*ksm(Y,mod-2,mod)%mod;
	p=(1-q+mod)%mod;
	pw[0]=1;
	for(int i=1;i<=K;i++) pw[i]=pw[i-1]*q%mod;
	printf("%lld\n",(sol(K)-sol(K-1)+mod)%mod);
	return 0;
}
posted @ 2021-09-02 21:05  雾隐  阅读(304)  评论(0编辑  收藏  举报