多项式取模优化线性递推总结
多项式取模优化线性递推总结
声明:博主已退役,这是以前的总结,如有错误望指正,如有问题
不妨看看别人的博客
线性递推
即对于数列\(\{a\}\)
已知前\(k\)项
且对于任意\(n\ge k\)有
\[a_n=\sum_{i=0}^{k-1}f_ia_{n-1-i}
\]
其中\(\{f\}\)是一个已知的数列
现在要求\(\{a\}\)的第\(n\)项
暴力是\(O(n*k)\)的
如果\(n\)太大就会超时
常用的优化方法是矩阵快速幂
复杂度\(O(k^3\log n)\)
但如果\(k\)比较大也会超时
甚至还不如暴力
\(\log n\)已经很优秀了
但是\(k^3\)实在太慢
注意到根据上面的式子
\(\{a\}\)所有数都可以被\(\{a_0,a_1,...,a_{k-1}\}\)线性表示
考虑已知\(a_n\)的线性表示如何求出\(a_{2n}\)的线性表示
这里应用一个性质
若
\[a_{n}=\sum_{i=0}^{k-1}b_ia_i\\
\]
则
\[a_{n+x}=\sum_{i=0}^{k-1}b_ia_{i+x}
\]
所以
\[a_{2n}=\sum_{i=0}^{k-1}b_ia_{n+i}\\
=\sum_{i=0}^{k-1}b_i\sum_{j=0}^{k-1}b_ja_{i+j}\\
=\sum_{i=0}^{2k-2}a_i\sum_{j=0}^{i}b_jb_{i-j}\\
(这里令b_x=0(x\ge k))
\]
这样就用\(\{a_0,a_1,...,a_{2k-2}\}\)线性表示了\(a_{2n}\)
只要知道\(\{a_k,a_{k+1},...,a_{2k-2}\}\)的线性表示然后带入即可
这一步倒着依次带入
复杂度优化为\(O(k^2\log n)\)
#include<bits/stdc++.h>
using namespace std;
#define gc c=getchar()
#define r(x) read(x)
#define ll long long
template<typename T>
inline void read(T&x){
x=0;T k=1;char gc;
while(!isdigit(c)){if(c=='-')k=-1;gc;}
while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
}
const int p=1000000007;
const int N=2000;
inline int add(int a,int b){
a+=b;
if(a>=p)a-=p;
return a;
}
int n,k;
int Tmp[N<<1];
inline void mul(int* a,int *b,int* f){
memset(Tmp,0,k<<3);
for(int i=0;i<k;++i){
for(int j=0;j<k;++j){
Tmp[i+j]=add(Tmp[i+j],(ll)a[i]*b[j]%p);
}
}
for(int i=(k<<1)-2;i>=k;--i){
for(int j=0;j<k;++j){
Tmp[i-j-1]=add(Tmp[i-j-1],(ll)Tmp[i]*f[j]%p);
}
}
memcpy(a,Tmp,k<<2);
}
int base[N],ans[N];
inline int solve(int* a,int* f,int n){
if(n<k)return a[n];
base[1]=ans[0]=1;
for(;n;n>>=1){
if(n&1)mul(ans,base,f);
mul(base,base,f);
}
int ret=0;
for(int i=0;i<k;++i)ret=add(ret,(ll)a[i]*ans[i]%p);
return ret;
}
int a[N],f[N];
int main(){
r(n);r(k);
for(int i=0;i<k;++i)r(f[i]),f[i]=add(f[i],p);
for(int i=0;i<k;++i)r(a[i]),a[i]=add(a[i],p);
printf("%d\n",solve(a,f,n));
}
多项式取模在哪里?
考虑上面代码的这一部分
for(int i=(k<<1)-2;i>=k;--i){
for(int j=0;j<k;++j){
Tmp[i-j-1]=add(Tmp[i-j-1],(ll)Tmp[i]*f[j]%p);
}
}
考虑消去第\(n\)位的时候
相当于把多项式\(\{-f_{k-1},-f_{k-2},...,-f_0,1\}\)平移了\(n-k\)位
并从原数列中减去它的\(Tmp_n\)倍
所以这段代码实际上是在对多项式\(\{-f_{k-1},-f_{k-2},...,-f_0,1\}\)取模
于是复杂度可以优化至\(O(k\log k \log n)\)
#include<bits/stdc++.h>
using namespace std;
#define gc c=getchar()
#define r(x) read(x)
#define ll long long
template<typename T>
inline void read(T&x){
x=0;T k=1;char gc;
while(!isdigit(c)){if(c=='-')k=-1;gc;}
while(isdigit(c)){x=x*10+c-'0';gc;}x*=k;
}
const int N=500000;
const int p=998244353;
const int g=3;
inline int qpow(int a,int b){
int ans=1;
for(;b;b>>=1){
if(b&1)ans=1ll*ans*a%p;
a=1ll*a*a%p;
}
return ans;
}
namespace polynomial{
int r[N];
int NOW_LEN;
inline void ntt(int *A,int len,int opt=1){
if(len!=NOW_LEN)for(int i=0;i<len;++i)r[i]=(r[i>>1]>>1)|((i&1)*(len>>1));
NOW_LEN=len;
for(int i=0;i<len;++i)if(i<r[i])swap(A[i],A[r[i]]);
for(int i=2;i<=len;i<<=1){
int wn=qpow(g,(p-1)/i),n=i>>1;
if(!opt)wn=qpow(wn,p-2);
for(int j=0;j<len;j+=i){
int w=1;
for(int k=0;k<n;++k,w=1ll*w*wn%p){
int u=A[j+k],v=1ll*A[j+k+n]*w%p;
A[j+k]=(u+v)%p;
A[j+k+n]=(u-v+p)%p;
}
}
}
if(!opt){
int inv=qpow(len,p-2);
for(int i=0;i<len;++i)A[i]=1ll*A[i]*inv%p;
}
}
int Tmp_mul1[N],Tmp_mul2[N];
inline void mul(int *A,int *B,int *C,int lenA,int lenB){
int len=1,lenC=lenA+lenB-1;
while(len<lenC)len<<=1;
memcpy(Tmp_mul1,A,lenA<<2);
memcpy(Tmp_mul2,B,lenB<<2);
memset(Tmp_mul1+lenA,0,(len-lenA)<<2);
memset(Tmp_mul2+lenB,0,(len-lenB)<<2);
ntt(Tmp_mul1,len);ntt(Tmp_mul2,len);
for(int i=0;i<len;++i)C[i]=1ll*Tmp_mul1[i]*Tmp_mul2[i]%p;
ntt(C,len,0);
memset(C+lenC,0,(len-lenC)<<2);
}
int Tmp_inv[N];
inline void inverse(int *A,int *Inv,int len){
memset(Inv,0,len<<2);
Inv[0]=qpow(A[0],p-2);
for(int i=2;i<=len;i<<=1){
memcpy(Tmp_inv,A,i<<2);
memset(Tmp_inv+i,0,i<<2);
ntt(Inv,i<<1);ntt(Tmp_inv,i<<1);
for(int k=0;k<i<<1;++k)Inv[k]=Inv[k]*(2-1ll*Inv[k]*Tmp_inv[k]%p+p)%p;
ntt(Inv,i<<1,0);
memset(Inv+i,0,i<<2);
}
}
int A0[N],B0[N];
inline void mod(int A[],int B[],int R[],int lenA,int lenB){
int len=1,t=lenA-lenB+1;
while(len<=t)len<<=1;
reverse_copy(B,B+lenB,A0);
inverse(A0,B0,len);
reverse_copy(A,A+lenA,A0);
mul(A0,B0,A0,t,t);
reverse(A0,A0+t);
for(len=1;len<(lenA<<1);len<<=1);
copy(B,B+lenB,B0);
mul(A0,B0,R,t,lenB);
for(int i=0;i<lenB-1;++i)R[i]=(A[i]-R[i]+p)%p;
}
}
int n,k;
int Tmp[N<<1];
inline void mul(int a[],int b[],int f[]){
polynomial::mul(a,b,Tmp,k,k);
polynomial::mod(Tmp,f,a,2*k,k+1);
}
int base[N],ans[N];
inline int solve(int a[],int f[],int n){
if(n<k)return a[n];
reverse(f,f+k);
for(int i=0;i<k;++i)f[i]=p-f[i];
f[k]=1;
base[1]=ans[0]=1;
for(;n;n>>=1){
if(n&1)mul(ans,base,f);
mul(base,base,f);
}
int ret=0;
for(int i=0;i<k;++i)ret=(ret+(ll)a[i]*ans[i]%p)%p;
return ret;
}
int a[N],f[N];
int main(){
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
r(n);r(k);
for(int i=0;i<k;++i)r(f[i]),f[i]=(f[i]+p)%p;
for(int i=0;i<k;++i)r(a[i]),a[i]=(a[i]+p)%p;
printf("%d\n",solve(a,f,n));
}