Berlekamp-Massey算法学习笔记
用途
\(O(n^2)\)求解一个长度为\(n\)的数列的最短线性递推式。
一般可以用于猜结论/骗分。
思想
从左到右依次扫过去,每当出现一个元素不符合原来的递推式时就修正它,得到新的递推式。
记当前已经有的递推式共有\(cnt\)个,第\(i\)个递推式出错的位置是\(fail_i\),出错时原数列与算出的结果的差是\(delta_i\),第\(i\)个递推式记作\(R_i\),把\(R\)代入\(i\)位置算出的结果为\(calc(i,R)\),于是有\(delta_i=a_{fail_i}-calc(fail_i,R_i)\)。
一开始的递推式是\(R_0=\{\}\),也就是空数列,\(cnt\)也为0。
若\(R_{cnt}\)在第\(i\)个位置出错了,那么可以得到\(delta_{cnt}\),记\(fail_{cnt}=i\)。
首先特判\(cnt=0\)。若\(cnt=0\),也就是之前一直是全\(0\)数列,那么直接设\(R_1=\{0,0,\cdots,0\}\),也就是用\(i\)个\(0\)填充,然后接着往后扫。
否则,把之前的某个\(R\)搬过来作为基准(记为\(R_p\))设\(mul=\frac{delta_{cnt}}{delta_p}\)。
请注意:\(p\)不能简单地取\(cnt-1\),否则不能保证递推式最短。
(hack方式:n=10,a={1,2,3,4,5,1,2,3,4,5}
,错误代码的递推式将会非常难看,而正确答案是0,0,0,0,1
)
现在希望得到一个\(R'\),使得\(R'\)在\(j<fail_{cnt}\)时有\(calc(j,R')=0\),且刚好有\(calc(fail_{cnt},R')=delta_i\),那么就可以得到新数列\(R_{cnt+1}=R_{cnt}+R'\)了。
\(R'\)怎么求?考虑把\(R_p\)搬过来,并在左边加上一个-1,也就是\(\{-1,R_p\}\)。容易发现,它在\(j< fail_p\)时有\(calc(j,R')=0\),并且\(calc(fail_{p},R')=-delta_{p}\)。那么把\(R'\)再乘一个\(-mul\)就可以使得\(calc(fail_{p},R')=delta_{cnt}\)。
但是它的位置好像有一些不对?我们想要的是\(calc(fail_{cnt},R')=delta_{cnt}\)啊。
这个简单。只需要再在数列左边添上\(fail_{cnt}-fail_p-1\)个0,相当于是平移了一下。
于是最后得到\(R'=\{0,0,\cdots,0,mul,-mul\times R_{p}\}\),令\(R_{cnt+1}=R_{cnt}+R'\),就完成了更新。
那么\(p\)究竟应该如何取呢?需要选取一个加完0之后长度最短的递推式,具体见代码。
至于为什么这样一定是最短的递推式,我也不知道qwq
复杂度:最坏情况下要更新\(n\)次,所以复杂度\(O(n^2)\)。
代码
#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
#define MP make_pair
#define rep(i,x,y) for (int i=(x);i<=(y);i++)
#define drep(i,x,y) for (int i=(x);i>=(y);i--)
#define go(x) for (int i=head[x];i;i=edge[i].nxt)
#define templ template<typename T>
#define sz 2020
typedef long long ll;
typedef double db;
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
templ inline void read(T& t)
{
t=0;char f=0,ch=getchar();double d=0.1;
while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
t=(f?-t:t);
}
template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
char __sr[1<<21],__z[20];int __C=-1,__zz=0;
inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
inline void print(register int x)
{
if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
while(__z[++__zz]=x%10+48,x/=10);
while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
}
void file()
{
#ifdef NTFOrz
freopen("a.in","r",stdin);
#endif
}
inline void chktime()
{
#ifndef ONLINE_JUDGE
cout<<(clock()-t)/1000.0<<'\n';
#endif
}
#ifdef mod
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
ll inv(ll x){return ksm(x,mod-2);}
#else
ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
#endif
// inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;
int n;
db a[sz];
int fail[sz],cnt;
vector<db>R[sz];
db delta[sz];
int main()
{
file();
read(n);
rep(i,1,n) read(a[i]);
int bst=0;
rep(i,1,n)
{
db cur=a[i];
rep(j,0,(int)R[cnt].size()-1) cur-=R[cnt][j]*a[i-j-1];
if (fabs(cur)<1e-7) continue;
delta[cnt]=cur;fail[cnt]=i;++cnt;
if (cnt==1){R[cnt].resize(i);continue;}
db mul=delta[cnt-1]/delta[bst];
vector<db>tmp;
tmp.resize(i-fail[bst]-1);tmp.push_back(mul);
rep(j,0,(int)R[bst].size()-1) tmp.push_back(-R[bst][j]*mul);
R[cnt]=tmp;if (R[cnt-1].size()>tmp.size()) R[cnt].resize(R[cnt-1].size());
rep(j,0,(int)R[cnt-1].size()-1) R[cnt][j]+=R[cnt-1][j];
if (i-fail[bst]+R[bst].size()>R[cnt-1].size()) bst=cnt-1;
}
rep(i,0,(int)R[cnt].size()-1) printf("%.5lf ",R[cnt][i]);
return 0;
}