常系数齐次线性递推

现在终于把线性递推常数减小了许多。(实际上我原先写的多项式取模才这么慢)

Fiduccia 算法

一个方法是求 \(x^n\) 对递推数列的特征多项式

\[p(x)=x^k-p_1x^{k-1}-p_2x^{k-2}-\cdots-p_k \]

取模的结果,然后逐项代入即可求得答案。这个取模可以像普通的快速幂一样,每次对 \(p(x)\) 取模。证明不会。

但是这玩意跑的非常慢,卡了半天常仍然只能做到每个点 1.7s 的水平。当然我本身大常数选手。

这玩意当然不是重点。放个代码仅供参考。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
using namespace std;
const int mod=998244353,g=3;
int n,k,wl,ans,a[150010],b[150010],r[150010],inv[150010],res[150010];
int q[150010],f[150010],c[150010],p[150010],tmp[150010];
int val[150010];
void get(int n){
    wl=1;
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*a*ans%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
const int invg=qpow(g,mod-2);
void ntt(int a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        int wn=qpow(tp==1?g:invg,(mod-1)/(mid<<1));
        for(int j=0;j<n;j+=mid<<1){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*wn%mod){
                int x=a[j+k],y=1ll*w*a[j+mid+k]%mod;
                a[j+k]=(x+y)%mod;a[j+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if(tp^1){
        int inv=qpow(n,mod-2);
        for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
    }
}
void getinv(int n,int a[],int b[]){
    if(n==1){
        b[0]=qpow(a[0],mod-2);
        return;
    }
    getinv((n+1)>>1,a,b);
    get(n<<1);
    for(int i=0;i<n;i++)c[i]=a[i];
    for(int i=n;i<wl;i++)c[i]=0;
    ntt(b,wl,1);ntt(c,wl,1);
    for(int i=0;i<wl;i++)b[i]=1ll*(2-1ll*b[i]*c[i]%mod+mod)%mod*b[i]%mod;
    ntt(b,wl,-1);
    for(int i=n;i<wl;i++)b[i]=0;
}
void getdiv(int n,int m,int a[],int b[]){
    get(n<<1);
    for(int i=0;i<wl;i++)q[i]=res[i]=c[i]=0;
    for(int i=0;i<=n;i++)q[n-i]=a[i];
    ntt(q,wl,1);
    for(int i=0;i<wl;i++)q[i]=1ll*q[i]*inv[i]%mod;
    ntt(q,wl,-1);
    for(int i=n-m+1;i<wl;i++)q[i]=0;
    reverse(q,q+n-m+1);
    ntt(b,wl,1);ntt(q,wl,1);
    for(int i=0;i<wl;i++)b[i]=1ll*b[i]*q[i]%mod;
    ntt(b,wl,-1);
    for(int i=0;i<m;i++)res[i]=(a[i]-b[i]+mod)%mod;
}
int main(){
    scanf("%d%d",&n,&k);
    for(int i=1;i<=k;i++)scanf("%d",&p[i]),p[i]=(p[i]+mod)%mod;
    reverse(p,p+k+1);p[k]=1;
    for(int i=0;i<k;i++)p[i]=(mod-p[i])%mod;
    for(int i=0;i<=k;i++)tmp[i]=f[k-i]=p[i];
    getinv(k-1,f,inv);
    get(k<<2);ntt(inv,wl,1);
    for(int i=0;i<k;i++)scanf("%d",&val[i]),val[i]=(val[i]+mod)%mod;
    b[1]=a[0]=1;
    while(n){
        if(n&1){
            get(k<<1);
            ntt(a,wl,1);ntt(b,wl,1);
            for(int i=0;i<wl;i++)a[i]=1ll*a[i]*b[i]%mod;
            ntt(a,wl,-1);ntt(b,wl,-1);
            getdiv((k-1)<<1,k,a,p);
            memcpy(a,res,sizeof(a));
            memcpy(p,tmp,sizeof(p));
        }
        get(k<<1);
        ntt(b,wl,1);
        for(int i=0;i<wl;i++)b[i]=1ll*b[i]*b[i]%mod;
        ntt(b,wl,-1);
        getdiv((k-1)<<1,k,b,p);
        memcpy(b,res,sizeof(b));
        memcpy(p,tmp,sizeof(p));
        n>>=1;
    }
    for(int i=0;i<k;i++)ans=(ans+1ll*val[i]*a[i])%mod;
    printf("%d\n",ans);
    return 0;
}

LSB-first 算法

重点来了。首先这玩意只能在模数是质数的情况下使用。

它的思想在于把递推数列转为生成函数的形式:

\[F(x)=\frac{P(x)}{Q(x)} \]

其中 \(Q(x)\) 是递推式的生成函数(就是特征多项式的系数翻转),而 \(P(x)\) 是一个次数小于 \(k\) 的多项式。

一个结论是一定存在这样的一个 \(P(x)\)。证明还是不会。那么我们就可以用 \(F(x)\) 的前 \(k\) 项和 \(Q(x)\) 计算 \(P(x)\) 的前 \(k\) 项。

然后考虑如何提取 \([x^n]\frac{P(x)}{Q(x)}\)。我们使用 Bostan-Mori 算法解决这个问题。

该算法依赖于

\[\frac{P(x)}{Q(x)}=\frac{P(x)Q(-x)}{Q(x)Q(-x)} \]

这玩意分母显然只有偶数次项不是 \(0\) ,那么求逆之后也是只有偶数次项不是 \(0\),也就是分子的奇数次项和偶数次项互不影响。那假如说我们要提取的 \(x^n\)\(n\) 是奇数的话,那么分子的偶数次项就可以全清掉。反之同理。这样每次做一次多项式乘法然后对分子分母的奇/偶次项重新标号,就可以把 \(n\) 减半。同时由于多项式乘法使得次数变成 \(2k\),这样重标号之后次数就仍然是 \(k\)。复杂度和上边的一样都是 \(O(k\log k\log n)\)

这玩意非常好写而且跑的很快,我写的这份时间空间都是上边那个的四分之一,450ms。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
using namespace std;
const int mod=998244353,g=3;
int n,k,wl,ans,f[150010],r[150010];
int q[150010],p[150010];
void get(int n){
    wl=1;
    while(n>=wl)wl<<=1;
    for(int i=0;i<=wl;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(__lg(wl)-1));
}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*a*ans%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
const int invg=qpow(g,mod-2);
void ntt(int a[],int n,int tp){
    for(int i=1;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int mid=1;mid<n;mid<<=1){
        int wn=qpow(tp==1?g:invg,(mod-1)/(mid<<1));
        for(int j=0;j<n;j+=mid<<1){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*wn%mod){
                int x=a[j+k],y=1ll*w*a[j+mid+k]%mod;
                a[j+k]=(x+y)%mod;a[j+mid+k]=(x-y+mod)%mod;
            }
        }
    }
    if(tp^1){
        int inv=qpow(n,mod-2);
        for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
    }
}
int main(){
    scanf("%d%d",&n,&k);q[0]=1;
    for(int i=1;i<=k;i++){
        int x;scanf("%d",&x);x=(x+mod)%mod;
        q[i]=(mod-x)%mod;
    }
    for(int i=0;i<k;i++){
        int x;scanf("%d",&x);p[i]=(x+mod)%mod;
    }
    for(int i=0;i<=k;i++)f[i]=q[i];
    get(k<<1);
    ntt(p,wl,1);ntt(f,wl,1);
    for(int i=0;i<wl;i++)p[i]=1ll*p[i]*f[i]%mod;
    ntt(p,wl,-1);
    for(int i=k;i<wl;i++)p[i]=0;
    while(n){
        for(int i=0;i<=k;i++)f[i]=(i&1)?(mod-q[i])%mod:q[i];
        for(int i=k+1;i<wl;i++)f[i]=0;
        ntt(f,wl,1);ntt(p,wl,1);ntt(q,wl,1);
        for(int i=0;i<wl;i++)p[i]=1ll*p[i]*f[i]%mod,q[i]=1ll*q[i]*f[i]%mod;
        ntt(p,wl,-1);ntt(q,wl,-1);
        for(int i=0;i<=k;i++)p[i]=p[(i<<1)|(n&1)],q[i]=q[i<<1];
        for(int i=k+1;i<wl;i++)p[i]=q[i]=0;
        n>>=1;
    }
    ans=1ll*p[0]*qpow(q[0],mod-2)%mod;
    printf("%d\n",ans);
    return 0;
}
posted @ 2023-02-04 08:46  gtm1514  阅读(34)  评论(0编辑  收藏  举报