Processing math: 0%

快速线性递推

线性递推的题目区域赛里还是挺多的,还是有必要学一下

 


 

~ 快速线性递推 ~

有一个n阶线性递推f,想要计算f(m),有一种常用的办法是矩阵快速幂,复杂度是O(n^3logm)

在不少情况下这已经够用了,但是如果n比较大、到了10^3级别,这就不太适用了

而存在算法能将这个复杂度压低到O(n^2logm),若加上NTT优化的话能做到O(n^2+nlognlogm),十分厉害

 

这个算法的核心是将f(m)用递推的前n项表示

即,已知f(0),...,f(n-1)和递推式f(m)=a_0f(m-1)+...+a_{n-1}f(m-n),该算法是求出系数W_0,...,W_{n-1},使得f(m)=W_0f(n-1)+...+W_{n-1}f(0)

 

看似无从下手?实际上只要大力展开就行了

根据定义,有(只是写成\sum的形式而已)

f(m)=\sum_{i=0}^{n-1}a_i f(m-1-i)

而对于每一项再次展开,即

f(m-1-i)=\sum_{j=0}^{n-1}a_j f(m-1-i-1-j)

全部代入,能得到

f(m)=\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_ia_j f(m-2-i-j)

把式子写的更好看一点,就是

f(m)=\sum_{k=0}^{2n-2}\sum_{i+j=k}a_ia_j f(m-2-k)

 

这样做之后有什么用呢?

在原本的递推式中,f(m)可以通过f(m-1),...,f(m-n)n个项表示

各项展开后,就可以通过f(m-2),...,f(m-2n)表示

事实上,我们可以再依次对f(m-i),2\leq i\leq n展开,并将系数向f(m-i-1),...,f(m-i-n)并入,最终就能把原递推式通过f(m-n-1),...,f(m-2n)n项表示

于是可以得到一个新的n阶递推式,记为f(m)=b_0f(m-n+1),...,b_{n-1}f(m-2n)

再用新递推式将各项展开,就可以通过f(m-2n-2),...,f(m-4n)表示

再用原递推式展开f(m-2n-i),2\leq i\leq n并向前合并系数,最终就能把原递推式通过f(m-3n+1),...,f(m-4n)n项表示

之后都是类似的了,不再赘述

 

有了上面的思路,就可以用类似快速幂的方法,得到f(m)=W_0f(m-(k-1)n+1),...,W_{n-1}f(m-kn)这样的展开式,其中m-kn<n

余数m-kn是我们不喜欢的,但也没有必要整体再向前推,一开始计算时算出f(0),...,f(2n-1)就够了

按照上述思路能这样实现:

复制代码
#include <cstdio>
#include <cstring>
using namespace std;

typedef long long ll;
const int MOD=1000000007;
const int N=1005;


int n,m;
int a[N];
int f[N<<1];

int tmp[N<<1];

void mul(int *y,int *x)
{
    memset(tmp,0,sizeof(tmp));
    
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
            tmp[i+j]=(tmp[i+j]+ll(y[i])*x[j])%MOD;
    
    for(int i=0;i<n-1;i++)
        for(int j=0;j<n;j++)
            tmp[i+j+1]=(tmp[i+j+1]+ll(tmp[i])*a[j])%MOD;
    
    for(int i=0;i<n;i++)
        y[i]=tmp[i+n-1]; 
}

int w[N<<1],x[N<<1];

int BM()
{
    if(m<(n<<1))
        return f[m];
    
    for(int i=0;i<n;i++)
        x[i]=a[i],w[i]=a[i];
    
    int t=(m-n)/n;
    int rem=m-n-t*n;
    
    while(t)
    {
        if(t&1)
            mul(w,x);
        mul(x,x);
        t>>=1;
    }
    
    int res=0;
    for(int i=0;i<n;i++)
        res=(res+ll(w[i])*f[rem+n-i-1])%MOD;
    return res;
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=0;i<n;i++)
        scanf("%d",&a[i]);
    for(int i=0;i<n;i++)
        scanf("%d",&f[i]);
    for(int i=n;i<(n<<1);i++)
        for(int j=1;j<=n;j++)
            f[i]=(f[i]+ll(a[j-1])*f[n-j])%MOD;
    
    printf("%d\n",BM());
    return 0;
}
View Code
复制代码

 

想做的更快的话,一个是要写NTT,另一个是合并系数会比较困难,待补

 


 

为了这题学的:牛客ACM 882BEddy Walker 2

m\rightarrow \infty时,f(m)\rightarrow \frac{2}{k+1} (并不会证...)

从rls那里学了一个证明:

k步,期望能走的长度是1+2+...+k=\frac{k(k+1)}{2}

那么在这段距离中,每个位置被走过的概率就是\frac{k}{\frac{k(k+1)}{2}}=\frac{2}{k+1}

在其他时候,直接套上面的板子即可

牛客的玄学评测机,同一份代码能差出500ms = =

复制代码
#include <cstdio>
#include <cstring>
using namespace std;
 
typedef long long ll;
const int MOD=1000000007;
const int N=1100;
 
inline int quickpow(int x,int t)
{
    int res=1;
    while(t)
    {
        if(t&1)
            res=ll(res)*x%MOD;
        x=ll(x)*x%MOD;
        t>>=1;
    }
    return res;
}
 
inline int rev(int x)
{
    return quickpow(x,MOD-2);
}
 
int n,rn;
ll m;
int a[N];
int f[N<<1];
 
int tmp[N<<1];
 
void mul(int *y,int *x)
{
    memset(tmp,0,sizeof(tmp));
     
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
            tmp[i+j]=(tmp[i+j]+ll(y[i])*x[j])%MOD;
     
    for(int i=0;i<n-1;i++)
        for(int j=0;j<n;j++)
            tmp[i+j+1]=(tmp[i+j+1]+ll(tmp[i])*a[j])%MOD;
     
    for(int i=0;i<n;i++)
        y[i]=tmp[i+n-1];
}
 
int w[N<<1],x[N<<1];
 
int BM()
{
    if(m<(n<<1))
        return f[m];
     
    for(int i=0;i<n;i++)
        x[i]=a[i],w[i]=a[i];
     
    ll t=(m-n)/n;
    int rem=m-n-t*n;
     
    while(t)
    {
        if(t&1)
            mul(w,x);
        mul(x,x);
        t>>=1;
    }
     
    int res=0;
    for(int i=0;i<n;i++)
        res=(res+ll(w[i])*f[rem+n-i-1])%MOD;
    return res;
}
 
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%lld",&n,&m);
         
        if(m==-1)
        {
            printf("%d\n",2LL*rev(n+1)%MOD);
            continue;
        }
         
        rn=rev(n);
        for(int i=0;i<n;i++)
            a[i]=rn;
         
        memset(f,0,sizeof(f));
        f[0]=1;
        for(int i=1;i<(n<<1);i++)
            for(int j=1;j<=n && j<=i;j++)
                f[i]=(f[i]+ll(rn)*f[i-j])%MOD;
         
        printf("%d\n",BM());
    }
     
    return 0;
}
View Code
复制代码

 


 

比较特定的知识点吧,以后遇到就是赚到(然后发现强制NTT,直接白给= =)

(完)

posted @   LiuRunky  阅读(1302)  评论(2编辑  收藏  举报
点击右上角即可分享
微信分享提示