BM算法学习笔记
一种nb算法,可以求出数列的递推式。
具体过程是这样的。
我们先假设它有一个递推式,然后按位去算他的值。
for(int j=0;j<now.size();++j)(delta[i]+=1ll*now[j]*f[i-j-1]%mod)%=mod;
这是我们算出了f[i]应当是多少,但是f[i]有可能不是我们算出的值,所以我们记录一个delta,为我们算出的值减去f[i]的结果。
然后查看一下之前有没有出过锅。
如果出过,那么就补一个0,然后塞过去。。
if(!cnt){now.resize(i);cnt++;continue;}
cnt记录出锅次数,now记录当前递推式。
然后我们需要构造一个递推式把这一位的delta补上。
然后我们设inv为这一次的dalta除以上一次的delta。
然后我们的递推式就是在last和now之间补0,然后加一个inv,后面把所有的pre*(-inv)加进去,这样最后n这个位置会出现-delta就满足我们的要求了。
最后我们把构造递推式和当前递推式加起来。
再贪心选一个左端点最靠右的出锅递推式作为last。
正确性???
代码
#include<iostream> #include<cstdio> #include<vector> #define N 100009 using namespace std; typedef long long ll; const ll mod=65521; ll n,f[N],delta[N],fail[N],cnt,last; vector<ll>cur,now,pre; inline int rd(){ int x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } inline ll power(ll x,ll y){ ll ans=1; while(y){ if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1; } return ans; } int main(){ n=rd(); for(int i=1;i<=n;++i)f[i]=rd(); for(int i=1;i<=n;++i){ delta[i]=mod-f[i]; for(int j=0;j<now.size();++j)(delta[i]+=1ll*now[j]*f[i-j-1]%mod)%=mod; if(!delta[i])continue; fail[cnt]=i; if(!cnt){now.resize(i);cnt++;continue;} ll inv=((mod-1ll*delta[i]*power(delta[fail[last]],mod-2)%mod)%mod+mod)%mod; cur.clear();cur.resize(i-fail[last]-1);cur.push_back(mod-inv); for(int j=0;j<pre.size();++j)cur.push_back(1ll*pre[j]*inv%mod); if(now.size()>cur.size())cur.resize(now.size()); for(int j=0;j<now.size();++j)(cur[j]+=now[j])%=mod; if(i-now.size()>=fail[last]-pre.size())pre=now,last=cnt; //fail[last]!!! cnt++;now=cur; } for(int i=0;i<now.size();++i)cout<<now[i]<<",";cout<<now.size(); return 0; }
应用:
[NOI2007]生成树计数
正解貌似是插头dp+快速幂。
然后我们发现k非常小。。。。
那么就可以对于每一个k打一个表,然后再扔到BM里跑一下,发现转移式子最大只有45。
于是就直接上矩乘。
代码
打表
#include<iostream> #include<cstdio> #include<cstring> #define N 402 using namespace std; typedef long long ll; ll a[N][N],n; const int mod=65521; inline ll power(ll x,ll y){ ll ans=1; while(y){ if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1; } return ans; } inline ll ni(ll x){return power(x,mod-2);} inline ll matr(int n){ for(int i=1;i<=n;++i){ for(int j=i+1;j<=n;++j){ ll x=1ll*a[j][i]*ni(a[i][i])%mod; for(int k=i;k<=n;++k)a[j][k]=(a[j][k]-x*a[i][k]%mod+mod)%mod; } } ll ans=1; for(int i=1;i<=n;++i)ans=ans*a[i][i]%mod; return ans; } int main(){ freopen("out","w",stdout); int kk=2; for(int n=1;n<=45;++n){ memset(a,0,sizeof(a)); for(int i=1;i<=n;++i){ for(int k=i-1;k>=1&&k>=i-kk;--k)a[i][i]++,a[i][k]--; for(int k=i+1;k<=n&&k<=i+kk;++k)a[i][i]++,a[i][k]--; } printf("%lld,",matr(n-1)); } return 0; }
矩阵乘法
#include<iostream> #include<cstdio> #include<cstring> using namespace std; typedef long long ll; const int mod=65521; ll top,n; int s1[2]={0,1}; int s2[4]={0,3,65520,0}; int s3[8]={0,5,65518,3,65516,1,0,0}; int s4[18]={0,7,65520,65496,31,65469,65437,300,65437,65469,31,65496,65520,7,65520,0,0,0}; int s5[46]={0,8,5,65489,40,364,63172,62845,2793,7304,50170,14272,13974, 32712,27590,63226,30516,31431,62449,44809,2992,62529,20712,3072,34090,35005,2295, 37931,32809,51547,51249,15351,58217,62728,2676,2349,65157,65481,32,65516,65513,1,0,0,0,0}; int a1[2]={0,1}; int a2[4]={0,1,1,3}; int a3[8]={0,1,1,3,16,75,336,1488}; int a4[18]={0,1,1,3,16,125,864,5635,35840,29517,48795,64376,52310,4486,28336,8758,64387,31184}; int a5[46]={0,1,1,3,16,125,1296,12005,38927,26915,65410,9167,63054,58705,18773,9079,38064,46824, 48121,50048,47533,30210,24390,51276,45393,357,44927,15398,15923,31582,56586,25233,41258,21255, 21563,16387,39423,26418,10008,6962,42377,50881,54893,50452,23715,53140}; inline ll power(ll x,ll y){ ll ans=1; while(y){if(y&1)ans=ans*x%mod;x=x*x%mod;y>>=1;} return ans; } struct matrix{ ll a[48][48]; matrix(){memset(a,0,sizeof(a));} matrix operator *(const matrix &b)const{ matrix c; for(int i=1;i<=top;++i) for(int j=1;j<=top;++j){ for(int k=1;k<=top;++k) (c.a[i][j]+=a[i][k]*b.a[k][j]%mod)%=mod; } return c; } }ans,Z; inline void work1(){ puts("1"); } inline void work2(){ if(n<=3){printf("%d\n",a2[n]);return;} for(int i=1;i<=3;++i){ ans.a[1][i]=a2[i]; Z.a[i][3]=s2[3-i+1]; if(i!=1)Z.a[i][i-1]=1; } n-=3;top=3; while(n){ if(n&1)ans=ans*Z; Z=Z*Z; n>>=1; } printf("%lld",ans.a[1][3]); } inline void work3(){ if(n<=7){printf("%d\n",a3[n]);return;} for(int i=1;i<=7;++i){ ans.a[1][i]=a3[i]; Z.a[i][7]=s3[7-i+1]; if(i!=1)Z.a[i][i-1]=1; } n-=7;top=7; while(n){ if(n&1)ans=ans*Z; Z=Z*Z; n>>=1; } printf("%lld",ans.a[1][7]); } inline void work4(){ if(n<=17){printf("%d\n",a4[n]);return;} for(int i=1;i<=17;++i){ ans.a[1][i]=a4[i]; Z.a[i][17]=s4[17-i+1]; if(i!=1)Z.a[i][i-1]=1; } n-=17;top=17; while(n){ if(n&1)ans=ans*Z; Z=Z*Z; n>>=1; } printf("%lld",ans.a[1][17]); } inline void work5(){ if(n<=45){printf("%d\n",a5[n]);return;} for(int i=1;i<=45;++i){ ans.a[1][i]=a5[i]; Z.a[i][45]=s5[45-i+1]; if(i!=1)Z.a[i][i-1]=1; } n-=45;top=45; while(n){ if(n&1)ans=ans*Z; Z=Z*Z; n>>=1; } printf("%lld",ans.a[1][45]); } int main(){ int k; cin>>k>>n; if(k==1)work1(); else if(k==2)work2(); else if(k==3)work3(); else if(k==4)work4(); else if(k==5)work5(); return 0; }