LG5487 【模板】线性递推+BM算法
【模板】线性递推+BM算法
给出一个数列 \(P\) 从 \(0\) 开始的前 \(n\) 项,求序列 \(P\) 在\(\bmod~998244353\) 下的最短线性递推式,并在 \(\bmod~ 998244353\) 下输出 \(P_m\)。
\(m\leq 10^9,1\leq n\leq 10000\),保证递推式最长不超过 \(5000\)。
Berlekamp-Massey 算法
Berlekamp-Massey 算法,常简称为 BM 算法,是用来求解一个数列的最短线性递推式的算法。
BM 算法可以在 \(O(n^2)\) 的时间内求解一个长度为 \(n\) 的数列的最短线性递推式。
基本定义
对于数列 \(A=\{a_1,a_2,\dots,a_n\}\),我们定义数列 \(R=\{r_1,r_2,\dots,r_m\}\) 为其线性递推式当且仅当
注意,无论你习惯从左到右写数列还是从右到左,这里的数列和线性递推式的位置对应关系是反着的。
所有可能的线性递推式 \(R\) 中长度 \(m\) 最小的叫做 \(A\) 的最短线性递推式。
算法流程
假设我们已经求得了 \(\{a_1,a_2,\dots,a_{i-1}\}\) 的最短线性递推式 \(\{r_1,r_2,\dots,r_m\}\),那么如何求得 \(\{a_1,a_2,\dots,a_i\}\) 的最短线性递推式?
定义 \(\{a_1,a_2,\dots,a_{i-1}\}\) 的最短线性递推式 \(\{r_1,r_2,\dots,r_m\}\) 为当前递推式,记递推式被更改的次数为 \(cnt\),第 \(i\) 次更改后的递推式为 \(R_i\),那么当前递推式应当为 \(R_{cnt}\)。特别地,定义 \(R_0=\varnothing\)。
我们对每个版本的 \(R\),记一个表示差异量的数组 \(\Delta_i\),满足
显然若 \(\Delta_i=0\),那么当前递推式就是 \(\{a_1,a_2,\dots,a_i\}\) 的最短线性递推式。
否则我们认为 \(R_{cnt}\) 在 \(a_i\) 处出错了,令 \(fail_i\) 为 \(R_i\) 最早的出错位置,则有 \(fail_{cnt}=i\)。考虑对 \(R_{cnt}\) 进行修改,使其变为 \(R_{cnt+1}\),并在 \(a_i\) 处同样成立。
若当前 \(cnt=0\),说明 \(a_i\) 是第一个非零元素,直接将 \(R_1\) 置为 \(\{ \underbrace{0,0,\dots,0}_{i} \}\) 即可,因为不可能用之前的数递推出 \(a_i\)。
否,即 \(cnt>0\),则考虑用 \(R_{cnt}\) 之前失败的递推式将这个 \(\Delta_i\) 加回去(\(a=\sum+\Delta\))。我们希望得到数列 \(R'=\{r'_1,r'_2,...,r'_{m'}\}\),使得
-
\[\forall k\in [m'+1,i-1],~\sum_{j=1}^{m'}r'_ja_{k-j}=0 \]
-
\[\sum_{j=1}^{m'}r'_ja_{i-j}=\Delta_i \]
如果能够找到这样的数列 \(R'\),那么令\(R_{cnt+1}=R_{cnt}+R'\)即可。这里加号表示各位对应相加。
在之前失败的递推式中任选一个 \(R_p\),尝试在它的基础上修改,在 \(i\) 的位置上构造出 \(\Delta_{fail_p}\)(这里的 \(\Delta\) 是对应 \(R_p\) 版本的),记得到的结果为 \(R_p'\),那么
考虑如何构造 \(R'_p\) 。将 \(R_p\) 的元素全部变成它的相反数,再在前面补上一个 \(1\) , \(\Delta_{fail_p}\) 就到 \(fail_p+1\) 位置上来了。
\[a_{fail_p}-\sum_{i=1}^{m_p}R_{p,i}a_{fail_p-i}=\Delta_{p,fail_p} \]
前面再补上 \(i-fail_p-1\) 个 \(0\),\(\Delta_{fail_p}\) 就到 \(i\) 位置上来了。于是
这里乘号表示顺次连接。
又因为 \(R_p\) 在 \(fail_p\) 前的 \(\Delta=0\),所以我们构造出来的 \(R'\) 是满足第一条约束的。
为了保证得到的递推式长度最短,我们需要选取恰当的 \(R_p\)。容易看出,得到的 \(R_{cnt+1}\) 的长度为 \(\max(i-fail_p+m_p,m)\)。于是记录 \(m_p-fail_p\) 最短的递推式作为 \(R_p\) 。
至此我们完成了 BM 算法的理论部分,在最坏情况下,我们可能需要对数列进行 \(O(n)\) 次修改,因此该算法的时间复杂度为 \(O(n^2)\)。
经验之谈
用 BM 得到的最短递推式长度最好要明显小于 \(n\) 的一半,否则需要再打些表。
为什么?因为若长度为 \(\frac n 2\),可以看做 \(\frac n 2\) 个变量列出 \(\frac n 2\) 个方程,总能找到解。所以一个随机数列解出的最短递推式长度就是 \(\frac n 2\) 左右。发生了这样的情况说明原数列很可能并没有一定的规律,即递推式大概率对之后的数据不适用。
另外因为计算中涉及除法,所以 BM 在实数域内求解可能有一定的精度误差。
namespace linear{
typedef vector<int> polynomial;
void num_trans(polynomial&a,int dir){
int lim=a.size();
static vector<int> rev,w[2];
if(rev.size()!=lim){
rev.resize(lim);
int len=log2(lim);
for(int i=0;i<lim;++i) rev[i]=rev[i>>1]>>1|(i&1)<<(len-1);
for(int dir=0;dir<2;++dir){
static co int g[2]={3,332748118};
w[dir].resize(lim);
w[dir][0]=1,w[dir][1]=fpow(g[dir],(mod-1)/lim);
for(int i=2;i<lim;++i) w[dir][i]=mul(w[dir][i-1],w[dir][1]);
}
}
for(int i=0;i<lim;++i)if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int step=1;step<lim;step<<=1){
int quot=lim/(step<<1);
for(int i=0;i<lim;i+=step<<1){
int j=i+step;
for(int k=0;k<step;++k){
int t=mul(w[dir][quot*k],a[j+k]);
a[j+k]=add(a[i+k],mod-t),a[i+k]=add(a[i+k],t);
}
}
}
if(dir){
int ilim=fpow(lim,mod-2);
for(int i=0;i<lim;++i) a[i]=mul(a[i],ilim);
}
}
polynomial poly_inv(polynomial a,int n){
polynomial b(1,fpow(a[0],mod-2));
if(n==1) return b;
int lim=2;
for(;lim<n;lim<<=1){
polynomial a1(a.begin(),a.begin()+lim);
a1.resize(lim<<1),num_trans(a1,0);
b.resize(lim<<1),num_trans(b,0);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a1[i],b[i])),b[i]);
num_trans(b,1),b.resize(lim);
}
a.resize(lim<<1),num_trans(a,0);
b.resize(lim<<1),num_trans(b,0);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a[i],b[i])),b[i]);
num_trans(b,1),b.resize(n);
return b;
}
polynomial operator/(polynomial f,polynomial g){
int n=f.size()-1,m=g.size()-1;
reverse(g.begin(),g.end()),g.resize(n-m+1),g=poly_inv(g,n-m+1);
reverse(f.begin(),f.end()),f.resize(n-m+1);
int lim=1<<int(ceil(log2((n-m)<<1|1)));
f.resize(lim),num_trans(f,0);
g.resize(lim),num_trans(g,0);
for(int i=0;i<lim;++i) f[i]=mul(f[i],g[i]);
num_trans(f,1),f.resize(n-m+1);
return reverse(f.begin(),f.end()),f;
}
polynomial operator%(polynomial f,polynomial g){
int n=f.size()-1,m=g.size()-1;
polynomial q=f/g;
int lim=1<<int(ceil(log2(n+1)));
q.resize(lim),num_trans(q,0);
g.resize(lim),num_trans(g,0);
for(int i=0;i<lim;++i) q[i]=mul(q[i],g[i]);
num_trans(q,1);
for(int i=0;i<m;++i) f[i]=add(f[i],mod-q[i]);
return f.resize(m),f;
}
int n,k;
void mul_mod(polynomial&a,polynomial b,co polynomial&p){
static co int lim=1<<int(ceil(log2(2*k-1)));
a.resize(lim),b.resize(lim);
num_trans(a,0),num_trans(b,0);
for(int i=0;i<lim;++i) a[i]=mul(a[i],b[i]);
num_trans(a,1),a.resize(2*k-1);
a=a%p;
}
void main(int _n,int _k,co vector<int>&_a,co vector<int>&_f){
n=_n,k=_k;
polynomial a(k),f(k);
for(int i=1;i<=k;++i) a[k-i]=mod-_a[i];
a.push_back(1);
for(int i=0;i<k;++i) f[i]=_f[i];
polynomial rmd(1,1),tmp(2);tmp[1]=1;
for(;n;n>>=1,mul_mod(tmp,tmp,a))
if(n&1) mul_mod(rmd,tmp,a);
int ans=0;
for(int i=0;i<k;++i) ans=add(ans,mul(rmd[i],f[i]));
printf("%d\n",ans);
}
}
vector<int> ber_ma(vector<int> f){
vector<int> lst,cur;
int lsfa,lsdel;
for(int i=0;i<(int)f.size();++i){
int del=f[i];
for(int j=1;j<(int)cur.size();++j)
del=add(del,mod-mul(cur[j],f[i-j]));
if(!del) continue;
if(!cur.size()){
cur.resize(i+1),lsfa=i,lsdel=del;
continue;
}
int alph=mul(del,fpow(lsdel,mod-2));
vector<int> nw(i-lsfa);
nw.push_back(alph);
for(int j=1;j<(int)lst.size();++j)
nw.push_back(mul(alph,mod-lst[j]));
if(nw.size()<cur.size()) nw.resize(cur.size());
for(int j=1;j<(int)cur.size();++j)
nw[j]=add(nw[j],cur[j]);
if(i-lsfa+(int)lst.size()>=(int)cur.size())
lst=cur,lsfa=i,lsdel=del;
cur=nw;
}
return cur;
}
int main(){
int n=read<int>(),m=read<int>();
vector<int> f(n);
for(int i=0;i<n;++i) read(f[i]);
vector<int> a=ber_ma(f);
for(int i=1;i<(int)a.size();++i) printf("%d ",a[i]);
puts("");
if(m<=n) {printf("%d\n",f[m]);return 0;}
linear::main(m,a.size()-1,a,f);
return 0;
}
线性递推式是 base 1 的,用 vector 存的话代码有点奇怪。