BM算法学习笔记

这似乎是一个很冷门的算法。


啥是BM?

Berlekamp-Massey算法是一种算法,可以找到给定二进制输出序列的最短线性反馈移位寄存器(LFSR)。
该算法还将在任意场中找到线性递归序列的最小多项式。
字段要求意味着Berlekamp-Massey算法要求所有非零元素都具有乘法逆。 Reeds和Sloane提供延伸处理戒指。

以上摘自百度百科,看上去非常的离谱是机翻?,我来把重点划一下。

在任意场中找到线性递归序列的最小多项式。

意思是说给出任意一个数列,都可以找出它的最小线性递推式(在有意义的情况下)。

大概是形如 \( \begin{aligned} a_{i}&=\sum^{m}_{j=1} r_{j}a_{i-j} \end{aligned} \)这样的数列,那么就称数列 \(r\) 为数列 \(a\) 的线性递推式。

而其中最短的那个就称之为最短线性递推式。

使用BM算法可以在 \(O(n^2)\) 的时间内得到 \(r\)

BM有啥用?

  • 虽然是个冷门的算法,一般题目里也用不到,但是在某些情况下可以进一步发掘题目性质。指打表用BM找规律
  • 在配合线性递推的情况下可以在想不出式子的时候艹过一些数数题。

BM该咋写?

构造一个正确的线性递推式

首先,如果当前的数列第一次出现了非 \(0\) 的数,设其为 \(i\) ,那么此时的递推式一定是一个长为 \(i\) 的全为 \(0\) 的序列。

假设前面已经得到了 \(R_{n-1}\) ,它对于一直到 \(a_{n-1}\) 的地方都是正确的,但是在 \(a_{n}\) 的时候挂了。

设最后用 \(R_{n-1}\) 推出来的答案 \(a'_{n}\)\(a_{n}\) 的差是 \(d_{n}\)

考虑使用增量法,找到一个 \(R'\),使得 \(R_{n}=R_{n-1}+R'\)

那么 $ R'$ ,需要满足 \( \left \{ \begin{aligned} &\sum\limits^{m'}_{j=1}r'_{j}a_{i-j}=0,i<n\\ &\sum\limits^{m'}_{j=1}r'_{j}a_{n-j}=d_{n} \end{aligned} \right . \)

我们考虑到对于前面的一个 \(G\) ,它是在 \(n\in[1,p)\) 的时候成立的最短递推式,其长度为 \(m\)

那么有:

\[\begin{aligned} \sum^{m}_{j=1}g_{j}a_{i-j}&=a_{i},i<p\\ \sum^{m}_{j=1}g_{j}a_{p-j}&=a_{p}-d_{p} \end{aligned} \]

发现这东西和 \(R'\) 长的非常像啊,是不是可以通过 \(G\) 构造出 \(R'\) 呢?

\(\delta(G)\) 表示 \(G\) 向右平移了 \(n-p\) 个单位,相当于在前面填上了 \(n-p\)\(0\)

接下来,只要在 \(\delta(G)\) 的每一个位置乘上一个 \(-\frac{d_{n}}{d_{p}}\) ,就可以让第 \(n\) 项的卷积变成 \(-\frac{d_n}{d_p}\times (a_{p}-d_{p})=d_{n}-\frac{d_{n}}{d_{p}}a_{p}\)

而后面多出来的那个 \(-\frac{d_{n}}{d_{p}}a_{p}\) ,可以直接通过在 \(\delta(G)_{n-p}\) 的位置上直接填一个 \(\frac{d_{n}}{d_{p}}\) 来抵消。

这样已经完成了第二个限制了,但是这么做不会破坏第一个限制吗?

可以发现,其它位置 \(i\) ,如果不考虑在 \(n-p\) 位置上的数,那么有 \(\sum\limits^{m}_{j=1}\delta g_{j+n-p}a_{i-j}=-\frac{d_{n}}{d_{p}}a_{i-n+p}\),它恰好和 \(\delta g_{n-p}a_{i-n+p}\) 抵消掉了!真神奇啊

因此,这样构造是正确的。

构造一个最短的线性递推式

可以发现,上面的做法对于任意的递推式 \(G\) ,都是适用的,那么怎么样才可以得到最短的递推式 \(R\) 呢?

其实这一步也非常简单,根据上面的推导过程可以发现上面得到的 \(R'\) 的长度是 \(len(G)+n-p\),显然只要取令\(len(G)-p\) 最小的递推式 \(G\) 就好了。

代码实现

浮点数的情况下
#include<cstdio>
#include<vector>
#include<cmath>
using namespace std;
#define il inline
#define ri register int
#define ll long long
#define ui unsigned int
il ll read(){
    bool f=true;ll x=0;
    register char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=false;ch=getchar();}
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    if(f) return x;
    return ~(--x);
}
il void write(const ll &x){if(x>9) write(x/10);putchar(x%10+'0');}
il void print(const ll &x) {x<0?putchar('-'),write(~(x-1)):write(x);putchar('\n');}
il ll max(const ll &a,const ll &b){return a>b?a:b;}
il ll min(const ll &a,const ll &b){return a<b?a:b;}
const int MAXN=3e3+7;
vector<double> f[MAXN],r;
double d[MAXN],a[MAXN];
int n,len[MAXN],fail[MAXN],cnt,m;
const double eps=1e-8;
int main(){
    n=read();
    for(ri i=1;i<=n;++i) a[i]=read();
    for(ri i=1;i<=n;++i){
        d[i]=a[i];
        for(ri j=0;j<f[cnt].size();++j) 
            d[i]-=f[cnt][j]*a[i-j-1];
        if(fabs(d[i])<=eps) continue;
        fail[cnt]=i;
        if(!cnt){
            f[++cnt].resize(i,0);
            continue;
        }
        int id=0,w=1e7;
        for(ri j=0;j<cnt;++j)
            if(i+f[j].size()-fail[j]<w)
                w=i+f[j].size()-fail[j],id=j;
        double x=d[i]/d[fail[id]];
        f[cnt+1]=f[cnt];
        cnt++;
        if(w>f[cnt].size()) f[cnt].resize(w);
        f[cnt][i-fail[id]-1]+=x;
        vector<double> &g=f[id];
        for(ri j=0;j<g.size();++j) 
            f[cnt][j+i-fail[id]]-=g[j]*x;
    }
    for(ri i=0;i<f[cnt].size();++i) 
        printf("%.0lf ",f[cnt][i]);
    return 0;
}

能过版题的代码(强行缝了一个线性递推上去)
#include<bits/stdc++.h>
using namespace std;
#define il inline
#define ri register int
#define ll long long
#define ui unsigned int
il ll read(){
    bool f=true;ll x=0;
    register char ch=getchar();
    while(ch<'0'||ch>'9') {if(ch=='-') f=false;ch=getchar();}
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    if(f) return x;
    return ~(--x);
}
il void write(const ll &x){if(x>9) write(x/10);putchar(x%10+'0');}
il void print(const ll &x) {x<0?putchar('-'),write(~(x-1)):write(x);putchar('\n');}
il ll max(const ll &a,const ll &b){return a>b?a:b;}
il ll min(const ll &a,const ll &b){return a<b?a:b;}
const ll mod=998244353;
il ll dec(ll x,ll y) {return x>=y?x-y:x-y+mod;}
il ll add(ll x,ll y) {return x+y<mod?x+y:x+y-mod;}
ll ksm(ll d,ll t){
    ll res=1;
    for(;t;t>>=1){
        if(t&1) res=res*d%mod;
        d=d*d%mod;
    }
    return res;
}
namespace GetP{
    const int MAXN=1e4+7;
    vector<ll> f[MAXN];
    ll d[MAXN],fail[MAXN],cnt,lst;
    vector<ll> main(int n,ll *a){
        for(ri i=0;i<n;++i){
            d[i]=a[i];
            for(ri j=0;j<f[cnt].size();++j)
                d[i]=dec(d[i],f[cnt][j]*a[i-j-1]%mod);
            if(!d[i]) continue;
            fail[cnt]=i;
            if(!cnt){
                f[++cnt].resize(i+1);
                continue;
            }
            int id=cnt-1,w=i+f[id].size()-fail[id];
            for(ri j=0;j<cnt;++j)
                if(i+f[j].size()-fail[j]<w)
                    id=j,w=i+f[j].size()-fail[j];
            
            ll x=d[i]*ksm(d[fail[id]],mod-2)%mod;
            f[cnt+1]=f[cnt],++cnt;
            if(f[cnt].size()<w) f[cnt].resize(w);
            int delta=i-fail[id];
            f[cnt][delta-1]=add(f[cnt][delta-1],x);
            for(ri j=0;j<f[id].size();++j)
                f[cnt][j+delta]=dec(f[cnt][j+delta],f[id][j]*x%mod);
        }
        return f[cnt];
    }
}
namespace Poly{
    const int N=4096<<3;
    vector<ll> w[23];
    ll fac[N],ifac[N],inv[N];
    il ll C(ll x,ll y){
        if(x<y) return 0;
        return fac[x]*ifac[y]%mod*ifac[x-y]%mod;
    }
    void init(){
        int n=N-1;
        fac[0]=1;
        for(ri i=1;i<=n;++i) fac[i]=i*fac[i-1]%mod;
        ifac[n]=ksm(fac[n],mod-2);
        for(ri i=n-1;~i;--i) ifac[i]=ifac[i+1]*(i+1)%mod;
        for(ri i=n;i;--i) inv[i]=ifac[i]*fac[i-1]%mod;
        inv[0]=1;
        int d=log(N)/log(2)+0.5;
        for(ri i=1;i<=d;++i){
            w[i].resize(1<<i);
            w[i][0]=1,w[i][1]=ksm(3,(mod-1)>>i);
            for(ri j=2;j<(1<<i);++j) w[i][j]=w[i][j-1]*w[i][1]%mod;
        }
    }
    int r[N];
    void DFT(int limit,ll *a,int flag){
        for(ri i=1;i<limit;++i) if(r[i]<i) swap(a[i],a[r[i]]);
        for(ri l=1,t=1;l<limit;l<<=1,++t){
            for(ri i=0;i<limit;i+=l<<1){
                ll *W=&w[t][0];
                for(ri j=0;j<l;++j){
                    ll tmp=a[i+j+l]*(*W++)%mod;
                    a[i+j+l]=dec(a[i+j],tmp);
                    a[i+j]=add(a[i+j],tmp);
                }
            }
        }
        if(flag==-1){
            reverse(a+1,a+limit);
            ll inv=ksm(limit,mod-2);
            for(ri i=0;i<limit;++i) a[i]=a[i]*inv%mod;
        }
    }
    ll _F[N];
    void rev(int limit,int ws){
        for(ri i=0;i<limit;i+=2){
            r[i]=r[i>>1]>>1;
            r[i|1]=(r[i>>1]>>1)|(1<<ws-1);
        }
    }
    void NTT(int n,ll *f,int m,ll *g,int lim=0){
        if(!lim) lim=n+m;
        int limit=1,ws=0;
        for(;limit<=n+m;++ws,limit<<=1);
        rev(limit,ws);
        for(ri i=0;i<=m;++i) _F[i]=g[i];
        for(ri i=m+1;i<limit;++i) _F[i]=0;
        for(ri i=n+1;i<limit;++i) f[i]=0;
        DFT(limit,f,1),DFT(limit,_F,1);
        for(ri i=0;i<limit;++i) f[i]=f[i]*_F[i]%mod;
        DFT(limit,f,-1);
        for(ri i=lim+1;i<limit;++i) f[i]=0;
    }
    void Inv(int n,ll *h,ll *f){
        memset(_F,0,sizeof(_F));
        f[0]=ksm(h[0],mod-2);
        for(ri t=1,l=2;1;l<<=1,++t){
            for(ri i=0;i<l;++i) _F[i]=h[i];
            rev(l<<1,t+1);
            DFT(l<<1,f,1),DFT(l<<1,_F,1);
            for(ri i=0;i<(l<<1);++i) f[i]=(2*f[i]-f[i]*f[i]%mod*_F[i]%mod+mod)%mod;
            DFT(l<<1,f,-1);
            if(l>n){
                for(ri i=n+1;i<(l<<1);++i) f[i]=0;
                break;
            }
            for(ri i=l;i<(l<<1);++i) f[i]=0;
        }
    }
    il void der(int n,ll *f){
        for(ri i=0;i<n;++i) 
            f[i]=f[i+1]*(i+1)%mod;
        f[n]=0;
    }
    il void Int(int n,ll *f){
        for(ri i=n;i;--i) 
            f[i]=f[i-1]*inv[i]%mod;
        f[0]=0;
    }
    ll _G[N];
    il void Ln(int n,ll *h,ll *f){
        memset(_G,0,sizeof(_G));
        for(ri i=0;i<=n;++i) f[i]=h[i];
        Inv(n,h,_G);
        der(n,f);
        NTT(n,f,n,_G);
        Int(n,f);
    }
    ll exp_f[N];
    il void exp(int n,ll *h,ll *f){
        f[0]=1;
        for(ri l=2,t=1;1;l<<=1,t++){
            memset(exp_f,0,sizeof(exp_f));
            Ln(l-1,f,exp_f);
            for(ri i=0;i<l;++i) exp_f[i]=(-exp_f[i]+h[i]+mod)%mod;
            exp_f[0]++;
            NTT(l>>1,f,l,exp_f);
            for(ri i=l;i<(l<<1);++i) f[i]=0;
            if(l>n){
                for(ri i=n+1;i<l;++i) f[i]=0;
                break;
            }
        }
    }
    ll ksm_f[N];
    il void polyksm(int n,ll *f,ll t){
        memset(ksm_f,0,sizeof(ksm_f));
        for(ri i=0;i<=n;++i) ksm_f[i]=f[i];
        Ln(n,ksm_f,f);
        for(ri i=0;i<=n;++i) ksm_f[i]=f[i]*t%mod,f[i]=0;
        exp(n,ksm_f,f);
    }
    il void div(int n,ll *f,int m,ll *g,ll *Q,ll *R){
        reverse(f,f+n+1),reverse(g,g+m+1);
        Inv(n-m,g,Q),NTT(n-m,Q,n-m,f,n-m);
        reverse(f,f+n+1),reverse(g,g+m+1),reverse(Q,Q+n-m+1);
        for(ri i=0;i<=m;++i) R[i]=g[i];NTT(m,R,n-m,Q);
        for(ri i=0;i<=n;++i) R[i]=dec(f[i],R[i]);
    }
    ll Mod_Q[N];
    il void Mod(int n,ll *f,int m,ll *g,ll *R){
        memset(Mod_Q,0,sizeof(Mod_Q));
        reverse(f,f+n+1),reverse(g,g+m+1);
        Inv(n-m,g,Mod_Q),NTT(n-m,Mod_Q,n-m,f,n-m);
        reverse(f,f+n+1),reverse(g,g+m+1),reverse(Mod_Q,Mod_Q+n-m+1);
        for(ri i=0;i<=m;++i) R[i]=g[i];NTT(m,R,n-m,Mod_Q);
        for(ri i=0;i<=n;++i) R[i]=dec(f[i],R[i]);
    }
}
using namespace Poly;
ll R[N],h[N];
ll f[N],a[N],p[N],g[N],ans;
void Ksm(ll t,ll n){
    f[0]=1,g[1]=1;
    int limit=1,ws=0;
    for(;limit<=(n<<1);limit<<=1,++ws);
    rev(limit,ws);
    for(;t;t>>=1){
        DFT(limit,g,1);
        if(t&1){
            DFT(limit,f,1);
            for(ri i=0;i<limit;++i) h[i]=f[i]*g[i]%mod,f[i]=R[i]=0;
            DFT(limit,h,-1);
            div(limit,h,n+1,p,R,f);
            rev(limit,ws);
        }
        for(ri i=0;i<limit;++i) h[i]=g[i]*g[i]%mod,g[i]=R[i]=0;
        DFT(limit,h,-1);
        div(limit,h,n+1,p,R,g);
        rev(limit,ws);
    }
}
ll n,k,m;
vector<ll> res;
int main(){
    init();
    n=read(),m=read();
    for(ri i=0;i<n;++i) a[i]=read();
    res=GetP::main(n,a);
    k=res.size();
    for(ri i=0;i<k;++i) printf("%lld ",res[i]);
    puts("");
    for(ri i=0;i<k;++i){
        p[k-i-1]=dec(mod,res[i]);
    }
    p[k]=1;
    Ksm(m,k-1);
    for(ri i=0;i<k;++i) ans=add(ans,a[i]*f[i]%mod);
    print(ans);
    return 0;
}

参考文献

Berlekamp_Massey 算法 (BM算法) 学习笔记- by zzd
感性理解Berlekamp-Massey算法- by Lstdo
【学习笔记】Berlekamp-Massey算法- by cz_xuyixuan

posted @ 2021-04-28 20:31  krimson  阅读(514)  评论(2编辑  收藏  举报