【学习笔记】BM(Berlekamp-Massey)算法
就是那个流氓一样的构造递推式的算法= =
所以做一些恐怖多项式题的时候可以试试 BM+肉眼观察x
BM算法通过增量法构造 每次考虑拟合下一个点
设我们现在的递推式为\(R_c\) 初始时\(R_0\)为空
考虑在数列末尾加入\(a_i\) 当前\(|R_c|=m\)
设\(\Delta_i = a_i - \sum_{j=1}^m a_{i-j} R_{c,j}\)
如果\(\Delta_i=0\)则直接拟合到这个点了 直接看下一个点
否则,设当前\(Fail_c\)的位置为\(i\) 即\(Fail_c=i\)
考虑补充一个\(R_{\Delta}\)使得\(R_{c+1}=R_c+R_\Delta\)可以拟合到当前的点
\(R_\Delta\)需要满足
- \(\forall |R_\Delta|+1\le k <i, \sum_{j=1}^{|R_\Delta|} a_{k-j}R_{\Delta j} = 0\)
- \(\sum_{j=1}^{|R_{\Delta}|} a_{i-j} R_{\Delta j}=\Delta_i\)
我们再看一下\(R_{c-1}\)都满足了哪些条件
1.\(\forall |R_{c-1}+1|\le k <Fail_{c-1}, a_k -\sum_{j=1}^{|R_{c-1}|} a_{k-j}R_{c-1, j} = 0\)
2.\(a_{Fail_{c-1}} -\sum_{j=1}^{R_{c-1}} a_{Fail_{c-1}-j} R_{c-1,j}=\Delta_{Fail_{c-1}}\)
发现这两个柿子其实很像
我们给出如下构造
令\(t=\frac{\Delta_i}{\Delta_{Fail_{c-1}}}\)
\(R_\Delta = \{0,0,\dots,0,t,-t\cdot R_{c-1,1} , -t\cdot R_{c-1,2},\dots\}\)
其中开头是\(i-Fail_{c-1} -1\)个0,\(t\)后面跟的是对应的\(|R_{c-1}|\)个值
考虑证明正确性
- \(\forall |R_\Delta|+1\le k <i\) 贡献是\(t\cdot (a_k -\sum_{j=1}^{|R_{c-1}|} a_{k-j}R_{c-1, j})=t\cdot 0 =0\) 细节和底下的2类似 可以先看下面的x
- \(t\)是第\(i-Fail_{c-1}\)项 对于第\(i\)项的贡献是 \(t\cdot a_{Fail_{c-1}}\) 考虑后面的\(|R_{c-1}|\)个值 对于答案的贡献是\(\sum_{j=1}^{|R_{c-1}|} -t \cdot R_{c-1,j} a_{Fail_{c-1}-j}\) 总贡献就是\(t \cdot (a_{Fail_{c-1}} -\sum_{j=1}^{R_{c-1}} a_{Fail_{c-1}-j} R_{c-1,j})=t\cdot \Delta_{Fail_{c-1}}=\Delta_i\) (转化见上面满足的条件)
所以\(R_\Delta\)符合要求 至于求最短的递推式呢 我们再额外枚举一个\(id\) 找到\(i-Fail_{id}+|R_{id}|\)最短的就可以了
真的流氓算法x
代码容我咕咕一下(。
昨天上午学完这个算法就病倒了...我有理由怀疑是算法的问题x
代码是洛谷的板子w
//Love and Freedom.
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<vector>
#define ll long long
#define inf 20021225
#define mdn 998244353
#define N 30100
#define G 3
using namespace std;
int read()
{
int s=0,t=1; char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-') t=-1; ch=getchar();}
while(ch>='0' && ch<='9') s=s*10+ch-'0',ch=getchar();
return s*t;
}
void upd(int &x,int y){x+=x+y>=mdn?y-mdn:y;}
int ksm(int bs,int mi)
{
int ans=1;
while(mi)
{
if(mi&1) ans=1ll*ans*bs%mdn;
bs=1ll*bs*bs%mdn,mi>>=1;
}
return ans;
}
int inv,r[N*4],n,k;
void ntt(int *a,int lim,int l,int f)
{
for(int i=0;i<lim;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
inv=ksm(lim,mdn-2);
for(int i=0;i<lim;i++) if(r[i]>i)
swap(a[r[i]],a[i]);
for(int k=2,mid=1;k<=lim;k<<=1,mid<<=1)
{
int Wn=ksm(G,(mdn-1)/k);
if(f) Wn=ksm(Wn,mdn-2);
for(int i=0,w=1;i<lim;i+=k,w=1)
for(int j=0;j<mid;j++,w=1ll*w*Wn%mdn)
{
int x=a[i+j],y=1ll*w*a[i+j+mid]%mdn;
a[i+j]=(x+y)%mdn; a[i+j+mid]=(x-y+mdn)%mdn;
}
}
if(f) for(int i=0;i<lim;i++) a[i]=1ll*a[i]*inv%mdn;
}
int f[N*4],g[N*4],h[N*4];
void poly_inv(int *a,int *g,int n)
{
if(n==1){g[0]=ksm(a[0],mdn-2); return;}
int mid=(n+1)>>1; poly_inv(a,g,mid);
int lim=1,l=0;
while(lim<(n<<1)) lim<<=1,l++;
for(int i=0;i<n;i++) h[i]=a[i];
for(int i=n;i<lim;i++) h[i]=0;
ntt(h,lim,l,0); ntt(g,lim,l,0);
for(int i=0;i<lim;i++)
g[i]=1ll*(mdn+2-1ll*h[i]*g[i]%mdn+mdn)%mdn*g[i]%mdn;
ntt(g,lim,l,1);
for(int i=n;i<lim;i++) g[i]=0;
}
int ff[N*4],fd[N*4],rm[N*4],q[N*4],rf[N*4],irg[N*4];
int st[N],xs[N],sg[N*4],ret[N*4],bs[N*4],a[N*4];
void poly_mod(int *a,int lim,int l)
{
int mi=(k<<1); while(a[--mi]==0); if(mi<k)return;
for(int i=0;i<lim;i++) rf[i]=0;
for(int i=0;i<=mi;i++) rf[i]=a[i];
reverse(rf,rf+mi+1);
for(int i=mi-k+1;i<=mi;i++) rf[i]=0;
ntt(rf,lim,l,0);
for(int i=0;i<lim;i++) q[i]=1ll*rf[i]*irg[i]%mdn;
ntt(q,lim,l,1);
for(int i=mi-k+1;i<=lim;i++) q[i]=0;
reverse(q,q+mi-k+1);ntt(q,lim,l,0);
for(int i=0;i<lim;i++) q[i]=1ll*q[i]*sg[i]%mdn;
ntt(q,lim,l,1);
for(int i=0;i<k;i++) a[i]=(a[i]+mdn-q[i])%mdn;
for(int i=k;i<=mi;i++) a[i]=0;
}
vector<int> coe[N]; int delta[N],fail[N],bas[N],tot;
void solve(int len,int *a,int *res)
{
int cur=0;
for(int i=1;i<=len;i++)
{
int tmp=a[i];
for(int j=0;j<coe[cur].size();j++)
upd(tmp,mdn-1ll*coe[cur][j]*a[i-j-1]%mdn);
delta[i]=tmp; if(!tmp) continue;
fail[cur]=i;
if(!cur){coe[++cur].resize(i); delta[i]=a[i]; continue;}
int id=cur-1,nlen=coe[id].size()-fail[id]+i;
for(int j=0;j<cur;j++)
if(i+coe[j].size()-fail[j]<nlen) nlen=i+coe[j].size()-fail[j],id=j;
coe[cur+1]=coe[cur],cur++;
while(coe[cur].size()<nlen) coe[cur].push_back(0);
int t=1ll*delta[i]*ksm(delta[fail[id]],mdn-2)%mdn;
upd(coe[cur][i-fail[id]-1],t);
for(int j=0;j<coe[id].size();j++)
upd(coe[cur][i-fail[id]+j],mdn-1ll*t*coe[id][j]%mdn);
}
tot=coe[cur].size();
for(int i=0;i<coe[cur].size();i++)
bas[i+1]=coe[cur][i];
}
int w[N],remd[N],cur[N];
int main()
{
int n=read(),m=read();
for(int i=1;i<=n;i++) st[i-1]=a[i]=read();
solve(n,a,bas);
for(int i=1;i<=tot;i++) printf("%d ",bas[i]);
for(int i=1;i<=tot;i++) sg[tot-i]=(mdn-bas[i])%mdn; sg[tot]=1;
int l=0,lim=1; k=tot;
while(lim<=k) lim<<=1,l++;
for(int i=0;i<=tot;i++) ret[i]=sg[i];
for(int i=0;i<=tot;i++) rf[i]=sg[i];
reverse(rf,rf+tot+1); poly_inv(rf,irg,lim);
for(int i=0;i<=tot;i++) rf[i]=0; lim<<=1,l++;
memset(a,0,sizeof(a));
ntt(sg,lim,l,0); ntt(irg,lim,l,0); a[1]=1; bs[0]=1;
while(m)
{
if(m&1)
{
ntt(bs,lim,l,0); ntt(a,lim,l,0);
for(int i=0;i<lim;i++) bs[i]=1ll*bs[i]*a[i]%mdn;
ntt(bs,lim,l,1); ntt(a,lim,l,1); poly_mod(bs,lim,l);
}
ntt(a,lim,l,0);
for(int i=0;i<lim;i++) a[i]=1ll*a[i]*a[i]%mdn;
ntt(a,lim,l,1);
poly_mod(a,lim,l);
m>>=1;
}
int ans=0;
for(int i=0;i<tot;i++) ans=1ll*(ans+1ll*bs[i]*st[i]%mdn)%mdn;
printf("\n%d\n",ans);
return 0;
}