浅谈 BM 算法
BM 算法
BM 算法,全名 Berlekamp-Massey 算法,是一个可以 \(O(n^2)\) 求出一个数列的最短线性递推式的算法。其主要思想(大概)是一项一项加入,若不符合当前猜测的递推式则对其进行调整。
假设我们欲求数列 \({a_0,a_1,\cdots,a_n}\) 的最短线性递推式,设 \(r^{(i)}\) 是 \(a_0,\cdots,a_i\) 的最短线性递推式,\(l_i\) 为 \(r^{(i)}\) 的阶数(可以直观理解成项数 +1)。初始时令 \(r^{(-1)}=\{1\}\),然后每次考虑加入一个新的数,对每个前缀计算其最短线性递推式。
引理
先给出一个引理,其给出了每个前缀的线性递推式的最短长度:如果 \(r^{(i-1)}\) 不是 \(a_0,\cdots,a_i\) 的线性递推式,则 \(l_i\ge\max(l_{i-1},i+1-l_{i-1})\)。
证明:\(l_i\ge l_{i-1}\) 显然在任何情况下成立。
若有 \(l_i< i+1-l_{i-1}\),则对于所有 \(j\ge l_i\) 下式成立:
这也说明 \(r^{(i-1)}\) 是 \(a_0,\cdots,a_i\) 的递推式,与假设矛盾,得证。
有限数列
这样我们就拿到了下界,且由反证过程可以看到它是紧的。我们来考虑一下能不能构造到这个下界。
假设 \(A\) 是 \(a\) 的生成函数,\(R_i\) 是 \(r^{(i)}\) 的生成函数。
考虑新加入一个数 \(a_i\),若之前的递推式仍然可行(即 \(AR_{i-1}\equiv S_i\pmod {x^{i+1}}\),其中 \(S_i\) 的次数 \(< l_{i-1}\)),则直接沿用之前的递推式即可。反之,则会出现 \(AR_{i-1}\equiv S_{i-1}+cx^i\pmod {x^{i+1}}\),考虑怎么用原来的方案修正。考虑我们上一次修正(假设在 \(p\) 时刻)的时候有 \(AR_{p-1}\equiv S_{p-1}+dx^p\pmod{x^{p+1}}\),为了构造出 \(cx^i\),我们给等式左右两边同乘一个 \(x^{i-p}cd^{-1}\),即 \(x^{i-p}cd^{-1}AR_{p-1}\equiv x^{i-p}cd^{-1}S_{p-1}+cx^i\pmod{x^{i+1}}\);然后我们再将两式相减,可以得到 \(A(R_{i-1}-x^{i-p}cd^{-1}R_{p-1})\equiv S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\pmod{x^{i+1}}\),即可得到 \(R_i=R_{i-1}-x^{i-p}cd^{-1}R_{p-1},S_i=S_{i-1}-x^{i-p}cd^{-1}S_{p-1}\)。
我们来归纳证明其取到了下界。假设有 \(l_p=\max(l_{p-1},p+1-l_{p-1})\),则由上面我们算出来的 \(R_i\) 的值有 \(l_i=\max(l_{i-1},i-p+l_{p-1})\)。若 \(l_p=l_{p-1}\),则继续往下归纳即可;否则有 \(l_p=p+1-l_{p-1}\),\(l_i=\max(l_{i-1},i-p+l_{p-1})=\max(l_{i-1},i+1-l_{p})\),得证。
无限数列
如果数列是无限数列,但我们知道 \(l_{\infty}\le s\),那么我们可以仅计算 \(a_0,\cdots,a_{2s}\) 的递推式,因为若我们在 \(t>2s\) 处更新了递推式,则会有 \(l_t=\max(l_{t-1},t+1-l_{t-1})>s\),与条件矛盾。
行/列向量列,矩阵列
假设欲计算 \(n\) 维行向量列 \(v_0,v_1,\cdots\) 的最短线性递推式,可以在模 \(p\) 意义下随机一个 \(n\) 维列向量 \(l\),则有 \(1-\frac{n}{p}\) 的概率使得 \(v_0l,v_1l,\cdots\) 的最短线性递推式就是 原行向量列的最短线性递推式。列向量同理。
矩阵列则可以随机一个 \(n\) 维列向量 \(u\) 与列向量 \(v\),计算 \(ub_0v,ub_1v,\cdots\) 的最短线性递推式即可,有 \(1-\frac{n+m}{p}\) 的概率出错。
概率好像是根据 Schwartz-Zippel 引理推知的,但是是怎么推来的嘛...不知道(
求矩阵最小多项式
\(n \times n\) 的矩阵 \(B\) 的最小多项式是次数最小的使得 \(f(B) = 0\) 的多项式 \(f\)。
对矩阵列 \(I,B,B^2,B^3,\cdots,B^{2n}\) 计算线性递推式即可,因为 \(B\) 的特征多项式满足 \(f(B)=0\),所以其最小多项式次数必定 \(\le n\)。
具体而言,我们需要计算 \(uIv,uBv,\cdots,uB^{2n}v\)。因为矩阵乘向量是 \(n^2\) 的,所以我们可以先 \(n^3\) 算出 \(uI,uB,\cdots,uB^{2n}\),再用 \(n^3\) 把 \(v\) 一一乘上去。
优化一类 dp
假设一类 dp 可以用矩阵快速幂计算(即形如 \(dp_{i,j}=\sum dp_{i-1,k}\cdot f_{k,j}\)),最后要求 \(\sum k_idp_{n,i}\) 或其他类似的东西,有一种万能的办法:把它的前 2k 项扔进 BM 算出它的递推式,然后直接 \(k^2\log n\) 或者 \(k\log k\log n\) 计算即可。因为矩阵是有最小多项式的,所以这个 dp 本质上也是一个线性递推,由 \(B\) 的最小多项式定义,所以其递推项数不超过 \(k\)。
知道了某个递推式,怎么快速算某一项?
code:
#include <bits/stdc++.h>
using namespace std;
int n , m , l[11000] , p;
long long a[11000] , r[11000] , rr[11000] , sav[11000] , las , noww , ans;
const int mod = 998244353;
long long exp( int a , int b )
{
long long ans = 1 , t = a;
while(b)
{
if(b & 1) ans = ans * t % mod;
t = t * t % mod; b >>= 1;
}
return ans;
}
struct poly
{
long long a[11000]; int len;
poly operator * ( const poly &x ) const
{
poly ans; memset(ans.a , 0 , sizeof(ans.a)); ans.len = len + x.len;
for(int i = 0 ; i <= len ; i++ )
{
for(int j = 0 ; j <= x.len ; j++ )
{
(ans.a[i + j] += a[i] * x.a[j] % mod) %= mod;
}
}
return ans;
}
poly operator % ( const poly &x ) const
{
poly ans; ans = (*this);
long long coe = exp(x.a[x.len] , mod - 2);
for(int i = ans.len ; i >= x.len ; i-- )
{
long long qwq = ans.a[i] * coe % mod;
for(int j = 0 ; j <= x.len ; j++ )
(ans.a[i - x.len + j] += mod - qwq * x.a[j] % mod) %= mod;
}
ans.len = min(ans.len , x.len - 1);
return ans;
}
} f , g , e;
poly exp( poly a , int b , poly p )
{
poly ans = e , t = a;
while(b)
{
if(b & 1) ans = (ans * t) % p;
t = (t * t) % p; b >>= 1;
}
return ans;
}
int main()
{
// freopen("1.in" , "r" , stdin);
// freopen("1.out" , "w" , stdout);
scanf("%d%d" , &n , &m);
for(int i = 0 ; i < n ; i++ ) scanf("%lld" , &a[i]); r[0] = 1; p = -1;
for(int i = 0 ; i < n ; i++ )
{
noww = 0; l[i] = l[i - 1];
for(int j = 0 ; j <= l[i] ; j++ ) (noww += r[j] * a[i - j] % mod) %= mod;
if(!noww) continue;
// cerr << las << ' ' << noww << endl;
long long coe = exp(las , mod - 2) * noww % mod;
if(!las)
{
memcpy(rr , r , sizeof(rr));
l[i] = i + 1; r[i + 1] = 1;
las = noww; p = i;
}
else
{
memcpy(sav , rr , sizeof(sav)); memcpy(rr , r , sizeof(rr));
l[i] = max(l[i - 1] , i + 1 - l[i - 1]);
for(int j = 0 ; j <= l[p - 1] ; j++ )
(r[i - p + j] += mod - coe * sav[j] % mod) %= mod;
p = i; las = noww;
}
// for(int j = 0 ; j <= l[i] ; j++ ) printf("%lld " , r[j]); printf("\n");
}
// for(int i = 0 ; i < n ; i++ ) cerr << l[i] << ' '; cerr << endl;
for(int i = 1 ; i <= l[n - 1] ; i++ ) printf("%lld " , (mod - r[i]) % mod); printf("\n");
e.a[0] = 1; g.a[1] = g.len = 1; f.len = l[n - 1];
for(int i = 0 ; i <= l[n - 1] ; i++ ) f.a[l[n - 1] - i] = r[i];
g = exp(g , m , f);
// for(int i = 0 ; i <= g.len ; i++ ) cerr << g.a[i] << ' '; cerr << endl;
for(int i = 0 ; i <= l[n - 1] ; i++ ) (ans += g.a[i] * a[i] % mod) %= mod;
printf("%lld" , ans);
return 0;
}
/*
*/