Berlekamp-Massey算法学习笔记

用途

\(O(n^2)\)求解一个长度为\(n\)的数列的最短线性递推式。

一般可以用于猜结论/骗分。

思想

从左到右依次扫过去,每当出现一个元素不符合原来的递推式时就修正它,得到新的递推式。

记当前已经有的递推式共有\(cnt​\)个,第\(i​\)个递推式出错的位置是\(fail_i​\),出错时原数列与算出的结果的差是\(delta_i​\),第\(i​\)个递推式记作\(R_i​\),把\(R​\)代入\(i​\)位置算出的结果为\(calc(i,R)​\),于是有\(delta_i=a_{fail_i}-calc(fail_i,R_i)​\)

一开始的递推式是\(R_0=\{\}\),也就是空数列,\(cnt\)也为0。

\(R_{cnt}\)在第\(i\)个位置出错了,那么可以得到\(delta_{cnt}\),记\(fail_{cnt}=i\)

首先特判\(cnt=0\)。若\(cnt=0\),也就是之前一直是全\(0\)数列,那么直接设\(R_1=\{0,0,\cdots,0\}\),也就是用\(i\)\(0\)填充,然后接着往后扫。

否则,把之前的某个\(R\)搬过来作为基准(记为\(R_p\))设\(mul=\frac{delta_{cnt}}{delta_p}\)

请注意:\(p\)不能简单地取\(cnt-1\),否则不能保证递推式最短。

(hack方式:n=10,a={1,2,3,4,5,1,2,3,4,5},错误代码的递推式将会非常难看,而正确答案是0,0,0,0,1

现在希望得到一个\(R'​\),使得\(R'​\)\(j<fail_{cnt}​\)时有\(calc(j,R')=0​\),且刚好有\(calc(fail_{cnt},R')=delta_i​\),那么就可以得到新数列\(R_{cnt+1}=R_{cnt}+R'​\)了。

\(R'​\)怎么求?考虑把\(R_p​\)搬过来,并在左边加上一个-1,也就是\(\{-1,R_p\}​\)。容易发现,它在\(j< fail_p​\)时有\(calc(j,R')=0​\),并且\(calc(fail_{p},R')=-delta_{p}​\)。那么把\(R'​\)再乘一个\(-mul​\)就可以使得\(calc(fail_{p},R')=delta_{cnt}​\)

但是它的位置好像有一些不对?我们想要的是\(calc(fail_{cnt},R')=delta_{cnt}\)啊。

这个简单。只需要再在数列左边添上\(fail_{cnt}-fail_p-1​\)个0,相当于是平移了一下。

于是最后得到\(R'=\{0,0,\cdots,0,mul,-mul\times R_{p}\}\),令\(R_{cnt+1}=R_{cnt}+R'\),就完成了更新。

那么\(p\)究竟应该如何取呢?需要选取一个加完0之后长度最短的递推式,具体见代码。

至于为什么这样一定是最短的递推式,我也不知道qwq

复杂度:最坏情况下要更新\(n\)次,所以复杂度\(O(n^2)\)

代码

#include<bits/stdc++.h>
clock_t t=clock();
namespace my_std{
	using namespace std;
	#define pii pair<int,int>
	#define fir first
	#define sec second
	#define MP make_pair
	#define rep(i,x,y) for (int i=(x);i<=(y);i++)
	#define drep(i,x,y) for (int i=(x);i>=(y);i--)
	#define go(x) for (int i=head[x];i;i=edge[i].nxt)
	#define templ template<typename T>
	#define sz 2020
	typedef long long ll;
	typedef double db;
	mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
	templ inline T rnd(T l,T r) {return uniform_int_distribution<T>(l,r)(rng);}
	templ inline bool chkmax(T &x,T y){return x<y?x=y,1:0;}
	templ inline bool chkmin(T &x,T y){return x>y?x=y,1:0;}
	templ inline void read(T& t)
	{
		t=0;char f=0,ch=getchar();double d=0.1;
		while(ch>'9'||ch<'0') f|=(ch=='-'),ch=getchar();
		while(ch<='9'&&ch>='0') t=t*10+ch-48,ch=getchar();
		if(ch=='.'){ch=getchar();while(ch<='9'&&ch>='0') t+=d*(ch^48),d*=0.1,ch=getchar();}
		t=(f?-t:t);
	}
	template<typename T,typename... Args>inline void read(T& t,Args&... args){read(t); read(args...);}
	char __sr[1<<21],__z[20];int __C=-1,__zz=0;
	inline void Ot(){fwrite(__sr,1,__C+1,stdout),__C=-1;}
	inline void print(register int x)
	{
		if(__C>1<<20)Ot();if(x<0)__sr[++__C]='-',x=-x;
		while(__z[++__zz]=x%10+48,x/=10);
		while(__sr[++__C]=__z[__zz],--__zz);__sr[++__C]='\n';
	}
	void file()
	{
		#ifdef NTFOrz
		freopen("a.in","r",stdin);
		#endif
	}
	inline void chktime()
	{
		#ifndef ONLINE_JUDGE
		cout<<(clock()-t)/1000.0<<'\n';
		#endif
	}
	#ifdef mod
	ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x%mod) if (y&1) ret=ret*x%mod;return ret;}
	ll inv(ll x){return ksm(x,mod-2);}
	#else
	ll ksm(ll x,int y){ll ret=1;for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;return ret;}
	#endif
//	inline ll mul(ll a,ll b){ll d=(ll)(a*(double)b/mod+0.5);ll ret=a*b-d*mod;if (ret<0) ret+=mod;return ret;}
}
using namespace my_std;

int n;
db a[sz];
int fail[sz],cnt;
vector<db>R[sz];
db delta[sz];

int main()
{
	file();
	read(n);
	rep(i,1,n) read(a[i]);
	int bst=0;
	rep(i,1,n)
	{
		db cur=a[i];
		rep(j,0,(int)R[cnt].size()-1) cur-=R[cnt][j]*a[i-j-1];
		if (fabs(cur)<1e-7) continue;
		delta[cnt]=cur;fail[cnt]=i;++cnt;
		if (cnt==1){R[cnt].resize(i);continue;}
		db mul=delta[cnt-1]/delta[bst];
		vector<db>tmp; 
		tmp.resize(i-fail[bst]-1);tmp.push_back(mul);
		rep(j,0,(int)R[bst].size()-1) tmp.push_back(-R[bst][j]*mul);
		R[cnt]=tmp;if (R[cnt-1].size()>tmp.size()) R[cnt].resize(R[cnt-1].size());
		rep(j,0,(int)R[cnt-1].size()-1) R[cnt][j]+=R[cnt-1][j];
		if (i-fail[bst]+R[bst].size()>R[cnt-1].size()) bst=cnt-1;
	}
	rep(i,0,(int)R[cnt].size()-1) printf("%.5lf ",R[cnt][i]);
	return 0;
}
posted @ 2019-05-10 14:09  p_b_p_b  阅读(874)  评论(0编辑  收藏  举报