[模板] 常系数齐次线性递推
一、题目
注意我的写的 \(a\) 和 \(f\) 和题目里面的是反的。
二、解法
我看 \(\tt oiwiki\) 上面的讲解就秒懂了!真的讲得特别特别好!
设 \(F(\sum c_ix^i)=\sum c_if_i\),\(F(x^n)\) 就是答案。
也就是我们用生成函数第 \(i\) 项作为 \(f_i\) 的记号
由于 \(f_n=\sum_{i=1}^k f_{n-i}a_i\),所以 \(F(x^n)=F(\sum_{i=1}^k a_ix^{n-i})\)
不难发现函数里面也可以直接减的,所以:
\[F(x^n-\sum_{i=1}^ka_ix^{n-i})=F(x^{n-k}(x^k-\sum_{i=0}^{k-1}a_{k-i}x^i))
\]
设 \(G(x)=x^k-\sum_{i=0}^{k-1}a_{k-i}x^i\),那么就有 \(F(A(x)+x^nG(x))=F(A(x))+F(x^nG(x))=F(A(x))\)
也就是说如果算 \(F(x^n)\) 的话就可以直接取模 \(G(x)\),设 \(P(x)=x^n\bmod G(x)\),那么答案就是 \(F(P(x))\),\(P(x)\) 是一个 \(k-1\) 次多项式所以可以直接根据定义算,用一个快速幂套多项式取模即可,时间复杂度 \(O(k\log n\log k)\)
写出来跑到了 \(\tt luogu\) 倒数第一,很好。
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define int long long
const int M = 400005;
const int MOD = 998244353;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,k,ans,r[M],a[M],b[M],f[M],g[M];
//b表示x^n模G(x)之后的多项式
namespace poly//封装了
{
int len,A[M],B[M],c[M],d[M],e[M],rev[M];
int qkpow(int a,int b)
{
int r=1;
while(b>0)
{
if(b&1) r=r*a%MOD;
a=a*a%MOD;
b>>=1;
}
return r;
}
void NTT(int *a,int len,int op)
{
for(int i=0;i<len;i++)
{
rev[i]=(rev[i>>1]>>1)|((i&1)*(len/2));
if(i<rev[i]) swap(a[i],a[rev[i]]);
}
for(int s=2;s<=len;s<<=1)
{
int t=s/2,w=(op==1)?qkpow(3,(MOD-1)/s):qkpow(3,MOD-1-(MOD-1)/s);
for(int i=0;i<len;i+=s)
for(int j=0,x=1;j<t;j++,x=x*w%MOD)
{
int fe=a[i+j],fo=a[i+j+t];
a[i+j]=(fe+x*fo)%MOD;
a[i+j+t]=((fe-x*fo)%MOD+MOD)%MOD;
}
}
if(op==1) return ;
int inv=qkpow(len,MOD-2);
for(int i=0;i<len;i++) a[i]=a[i]*inv%MOD;
}
void work(int n,int *a,int *b)//逆元从属函数
{
len=1;while(len<2*n) len<<=1;
for(int i=0;i<len;i++) A[i]=B[i]=0;
for(int i=0;i<n;i++) A[i]=a[i];
for(int i=0;i<(n/2);i++) B[i]=b[i];
NTT(A,len,1);NTT(B,len,1);
for(int i=0;i<len;i++)
A[i]=((2*B[i]-B[i]*B[i]%MOD*A[i])%MOD+MOD)%MOD;
NTT(A,len,-1);
for(int i=0;i<n;i++) b[i]=A[i];
}
void inv(int n,int *a,int *b)//逆元存在b那里
{
b[0]=qkpow(a[0],MOD-2);
int cur=1;
while(cur<n)
{
cur<<=1;
work(cur,a,b);
}
}
void mul(int n,int *a,int *b)//多项式乘法
{
len=1;while(len<2*n) len<<=1;
for(int i=0;i<len;i++) A[i]=B[i]=0;
for(int i=0;i<n;i++) A[i]=a[i],B[i]=b[i];
NTT(A,len,1);NTT(B,len,1);
for(int i=0;i<len;i++) A[i]=A[i]*B[i]%MOD;
NTT(A,len,-1);
for(int i=0;i<2*n;i++) b[i]=A[i];
}
void mod(int n,int m,int *a,int *b)
//n次多项式a取模m次多项式b
//最后的结果是余数,存在a处
{
//翻转A
for(int i=0;i<=n;i++) d[i]=a[i];
for(int i=0;i<=n/2;i++) swap(d[i],d[n-i]);
//翻转B
for(int i=0;i<=m;i++) e[i]=b[i];
for(int i=0;i<=m/2;i++) swap(e[i],e[m-i]);
inv(n-m+1,e,c);
//清除掉无用的部分
len=1;while(len<=2*(n-m)) len<<=1;
for(int i=n-m+1;i<len;i++) d[i]=c[i]=0;
NTT(c,len,1);NTT(d,len,1);
for(int i=0;i<len;i++) c[i]=c[i]*d[i]%MOD;
NTT(c,len,-1);
for(int i=n-m+1;i<len;i++) c[i]=0;
for(int i=0;i<=(n-m)/2;i++) swap(c[i],c[n-m-i]);
//算余数
len=1;while(len<=n) len<<=1;
for(int i=0;i<=m/2;i++) swap(e[i],e[m-i]);//翻转回来
NTT(c,len,1);NTT(e,len,1);
for(int i=0;i<len;i++) c[i]=c[i]*e[i]%MOD;
NTT(c,len,-1);
for(int i=0;i<=n;i++)
a[i]=((a[i]-c[i])+MOD)%MOD;
for(int i=0;i<len;i++) e[i]=c[i]=0;//用了就清
}
};
signed main()
{
n=read();k=read();
for(int i=1;i<=k;i++)
b[i]=read();
for(int i=0;i<k;i++)
g[i]=(MOD-b[k-i])%MOD;
g[k]=1;
for(int i=0;i<k;i++)
f[i]=read();
a[1]=1;//初始值是x
r[0]=1;//初始值是1
while(n>0)
{
if(n&1)
{
poly::mul(k,a,r);
poly::mod(2*k-2,k,r,g);
}
poly::mul(k,a,a);
poly::mod(2*k-2,k,a,g);
n>>=1;
}
for(int i=0;i<k;i++)
ans=(ans+f[i]*r[i])%MOD;
printf("%lld\n",(ans+MOD)%MOD);
}