浅谈 BM 算法

BM 算法

BM 算法,全名 Berlekamp-Massey 算法,是一个可以 \(O(n^2)\)​ 求出一个数列的最短线性递推式的算法。其主要思想(大概)是一项一项加入,若不符合当前猜测的递推式则对其进行调整。

假设我们欲求数列 \({a_0,a_1,\cdots,a_n}\)​ 的最短线性递推式,设 \(r^{(i)}\)​ 是 \(a_0,\cdots,a_i\)​ 的最短线性递推式,\(l_i\)​ 为 \(r^{(i)}\)​ 的阶数(可以直观理解成项数 +1)。初始时令 \(r^{(-1)}=\{1\}\)​,然后每次考虑加入一个新的数,对每个前缀计算其最短线性递推式。

引理

先给出一个引理,其给出了每个前缀的线性递推式的最短长度:如果 \(r^{(i-1)}\) 不是 \(a_0,\cdots,a_i\) 的线性递推式,则 \(l_i\ge\max(l_{i-1},i+1-l_{i-1})\)

证明:\(l_i\ge l_{i-1}\) 显然在任何情况下成立。

若有 \(l_i< i+1-l_{i-1}\),则对于所有 \(j\ge l_i\) 下式成立:

\[\begin{aligned} a_i&=-\sum_{j=1}^{l_{i-1}}p_ja_{i-j}=\sum_{j=1}^{l_{i-1}}p_j\sum_{k=1}^{l_i}q_ka_{i-j-k}\ (l_i< i+1-l_{i-1}) \\&=\sum_{j=1}^{l_i}q_j\sum_{k=1}^{l_{i-1}}p_ka_{i-j-k}=-\sum_{j=1}^{l_i}q_ja_{i-j} \end{aligned} \]

这也说明 \(r^{(i-1)}\)​ 是 \(a_0,\cdots,a_i\)​​ 的递推式,与假设矛盾,得证。

有限数列

这样我们就拿到了下界,且由反证过程可以看到它是紧的。我们来考虑一下能不能构造到这个下界。

假设 \(A\)\(a\) 的生成函数,\(R_i\)\(r^{(i)}\) 的生成函数。

考虑新加入一个数 \(a_i\)​,若之前的递推式仍然可行(即 \(AR_{i-1}\equiv S_i\pmod {x^{i+1}}\)​,其中 \(S_i\)​ 的次数 \(< l_{i-1}\)​),则直接沿用之前的递推式即可。反之,则会出现 \(AR_{i-1}\equiv S_{i-1}+cx^i\pmod {x^{i+1}}\)​,考虑怎么用原来的方案修正。考虑我们上一次修正(假设在 \(p\)​ 时刻)的时候有 \(AR_{p-1}\equiv S_{p-1}+dx^p\pmod{x^{p+1}}\)​,为了构造出 \(cx^i\)​,我们给等式左右两边同乘一个 \(x^{i-p}cd^{-1}\)​,即 \(x^{i-p}cd^{-1}AR_{p-1}\equiv x^{i-p}cd^{-1}S_{p-1}+cx^i\pmod{x^{i+1}}\)​;然后我们再将两式相减,可以得到 \(A(R_{i-1}-x^{i-p}cd^{-1}R_{p-1})\equiv S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\pmod{x^{i+1}}\)​,即可得到 \(R_i=R_{i-1}-x^{i-p}cd^{-1}R_{p-1},S_i=S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\)​。

我们来归纳证明其取到了下界。假设有 \(l_p=\max(l_{p-1},p+1-l_{p-1})\),则由上面我们算出来的 \(R_i\) 的值有 \(l_i=\max(l_{i-1},i-p+l_{p-1})\)。若 \(l_p=l_{p-1}\),则继续往下归纳即可;否则有 \(l_p=p+1-l_{p-1}\)\(l_i=\max(l_{i-1},i-p+l_{p-1})=\max(l_{i-1},i+1-l_{p})\)​,得证。

无限数列

如果数列是无限数列,但我们知道 \(l_{\infty}\le s\),那么我们可以仅计算 \(a_0,\cdots,a_{2s}\) 的递推式,因为若我们在 \(t>2s\) 处更新了递推式,则会有 \(l_t=\max(l_{t-1},t+1-l_{t-1})>s\),与条件矛盾。

行/列向量列,矩阵列

假设欲计算 \(n\) 维行向量列 \(v_0,v_1,\cdots\) 的最短线性递推式,可以在模 \(p\) 意义下随机一个 \(n\) 维列向量 \(l\),则有 \(1-\frac{n}{p}\) 的概率使得 \(v_0l,v_1l,\cdots\)​​​ 的最短线性递推式就是 原行向量列的最短线性递推式。列向量同理。

矩阵列则可以随机一个 \(n\)​ 维列向量 \(u\)​ 与列向量 \(v\)​,计算 \(ub_0v,ub_1v,\cdots\)​ 的最短线性递推式即可,有 \(1-\frac{n+m}{p}\)​​ 的概率出错。

概率好像是根据 Schwartz-Zippel 引理推知的,但是是怎么推来的嘛...不知道(

求矩阵最小多项式

\(n \times n\) 的矩阵 \(B\) 的最小多项式是次数最小的使得 \(f(B) = 0\) 的多项式 \(f\)

对矩阵列 \(I,B,B^2,B^3,\cdots,B^{2n}\)​ 计算线性递推式即可,因为​ \(B\) 的特征多项式满足 \(f(B)=0\),所以其最小多项式次数必定 \(\le n\)

具体而言,我们需要计算 \(uIv,uBv,\cdots,uB^{2n}v\)。因为矩阵乘向量是 \(n^2\) 的,所以我们可以先 \(n^3\) 算出 \(uI,uB,\cdots,uB^{2n}\),再用 \(n^3\)\(v\) 一一乘上去。

优化一类 dp

假设一类 dp 可以用矩阵快速幂计算(即形如 \(dp_{i,j}=\sum dp_{i-1,k}\cdot f_{k,j}\)),最后要求 \(\sum k_idp_{n,i}\) 或其他类似的东西,有一种万能的办法:把它的前 2k 项扔进 BM 算出它的递推式,然后直接 \(k^2\log n\) 或者 \(k\log k\log n\) 计算即可。因为矩阵是有最小多项式的,所以这个 dp 本质上也是一个线性递推,由 \(B\) 的最小多项式定义,所以其递推项数不超过 \(k\)

知道了某个递推式,怎么快速算某一项?

Cayley-Hamilton 定理

code:

#include <bits/stdc++.h>
using namespace std;
int n , m , l[11000] , p;
long long a[11000] , r[11000] , rr[11000] , sav[11000] , las , noww , ans;
const int mod = 998244353;
long long exp( int a , int b )
{
	long long ans = 1 , t = a;
	while(b)
	{
		if(b & 1) ans = ans * t % mod;
		t = t * t % mod; b >>= 1;
	}
	return ans;
}
struct poly
{
	long long a[11000]; int len;
	poly operator * ( const poly &x ) const
	{
		poly ans; memset(ans.a , 0 , sizeof(ans.a)); ans.len = len + x.len;
		for(int i = 0 ; i <= len ; i++ )
		{
			for(int j = 0 ; j <= x.len ; j++ )
			{
				(ans.a[i + j] += a[i] * x.a[j] % mod) %= mod;
			}
		}
		return ans;
	}
	poly operator % ( const poly &x ) const
	{
		poly ans; ans = (*this);
		long long coe = exp(x.a[x.len] , mod - 2);
		for(int i = ans.len ; i >= x.len ; i-- )
		{
			long long qwq = ans.a[i] * coe % mod;
			for(int j = 0 ; j <= x.len ; j++ ) 
				(ans.a[i - x.len + j] += mod - qwq * x.a[j] % mod) %= mod;
		}
		ans.len = min(ans.len , x.len - 1);
		return ans;
	} 
} f , g , e;
poly exp( poly a , int b , poly p )
{
	poly ans = e , t = a;
	while(b)
	{
		if(b & 1) ans = (ans * t) % p;
		t = (t * t) % p; b >>= 1;
	}
	return ans;
}
int main() 
{
//	freopen("1.in" , "r" , stdin);
//	freopen("1.out" , "w" , stdout);
	scanf("%d%d" , &n , &m);
	for(int i = 0 ; i < n ; i++ ) scanf("%lld" , &a[i]); r[0] = 1; p = -1;
	for(int i = 0 ; i < n ; i++ )
	{
		noww = 0; l[i] = l[i - 1];
		for(int j = 0 ; j <= l[i] ; j++ ) (noww += r[j] * a[i - j] % mod) %= mod;
		if(!noww) continue;
//		cerr << las << ' ' << noww << endl;
		long long coe = exp(las , mod - 2) * noww % mod;
		if(!las)
		{
			memcpy(rr , r , sizeof(rr));
			l[i] = i + 1; r[i + 1] = 1;
			las = noww; p = i;
		} 
		else
		{
			memcpy(sav , rr , sizeof(sav)); memcpy(rr , r , sizeof(rr));
			l[i] = max(l[i - 1] , i + 1 - l[i - 1]);
			for(int j = 0 ; j <= l[p - 1] ; j++ )
				(r[i - p + j] += mod - coe * sav[j] % mod) %= mod;
			p = i; las = noww;
		}
//		for(int j = 0 ; j <= l[i] ; j++ ) printf("%lld " , r[j]); printf("\n");
	}
//	for(int i = 0 ; i < n ; i++ ) cerr << l[i] << ' '; cerr << endl;
	for(int i = 1 ; i <= l[n - 1] ; i++ ) printf("%lld " , (mod - r[i]) % mod); printf("\n");
	e.a[0] = 1; g.a[1] = g.len = 1; f.len = l[n - 1];
	for(int i = 0 ; i <= l[n - 1] ; i++ ) f.a[l[n - 1] - i] = r[i];
	g = exp(g , m , f);
//	for(int i = 0 ; i <= g.len ; i++ ) cerr << g.a[i] << ' '; cerr << endl;
	for(int i = 0 ; i <= l[n - 1] ; i++ ) (ans += g.a[i] * a[i] % mod) %= mod;
	printf("%lld" , ans);
    return 0; 
}
/*
*/
posted @ 2022-02-13 10:41  恨妹不成穹  阅读(1734)  评论(0编辑  收藏  举报