【XSY3915】数学作业(常系数齐次线性递推,多项式,矩阵快速幂)
题面
题解
看到 \(m\) 很大,联想到矩阵快速幂。
由于对于每个初始的 \(x\),它变成 \(1\) 的方法是唯一的。所以我们可以考虑从 \(1\) 倒推,这样用不同的方法倒推得到的数肯定是不同的,所以不会算重。
为了方便,我们从 \(0\) 而不是 \(1\) 开始倒推,此时原来倒推 \(m\) 步就变成了倒推 \(m+1\) 步(这样会算多,但我们最后再处理多出来的这一部分)。
假设当前数为 \(x\),考虑 \(x\) 向前倒推:若 \(x\not \equiv k-1 \pmod k\),则 \(x\) 可以由 \(x+1\) 和 \(xk\) 转移得到;若 \(x\equiv k-1\pmod k\),则 \(x\) 只可以由 \(xk\) 转移得到。
那么设 \(f_{i,j}\) 表示经过倒推 \(i\) 步后,有多少个数是模 \(k\) 余 \(j\) 的(即有多少个模 \(k\) 余 \(j\) 的数经过 \(i\) 步操作后变成 \(1\)),容易得到初始矩阵 \(G\) 和转移矩阵 \(A\):(矩阵的行列从 \(0\) 开始编号)
那么我们要求的是 \(GA^{m+1}\) 的第 \(0\) 列所有数的和(即 \(\sum\limits_{i=0}^{k-1}f_{m+1,i}\)),即 \(GA^{m+2}\) 的第 \(0\) 列第 \(0\) 行的数。
我们先得到 \(A\) 的特征多项式 \(f(\lambda)=|\lambda I-A|=\lambda^k-\sum\limits_{i=0}^{k-1}\lambda^i=\lambda^k-\lambda^{k-1}-\lambda^{k-2}-\cdots-\lambda^0\)。
由某个 C 开头的定理知 \(f(A)=0\)。
由于我们要求 \(A^{m+2}\),所以我们不妨将 \(A^{m+2}\) 一直减去 \(f(A)\) 直到次数恰好小于 \(k\) 为止(也就是取模)。
不妨令 \(g(\lambda)=\lambda^{m+2}\bmod f(\lambda)\)(这个可以用边快速幂边取模的方法求出。注意 \(f(\lambda)\) 的形式很特殊,可以 \(O(k)\) 简单取模)。将 \(g(\lambda)\) 展开,设 \(g(\lambda)=\sum\limits_{i=0}^{k-1}a_i\lambda^i\)。
那么 \(A^{m+2}=A^{m+2}\bmod f(A)=g(A)=\sum\limits_{i=0}^{k-1}a_iA^i\)。
那么我们要求的就是:
注意到 \(\left(GA^i\right)_{0,0}\) 就是 \(f_{i,0}\),所以原式即为:\(\left(GA^{m+2}\right)_{0,0}=\sum\limits_{i=0}^{k-1}a_if_{i,0}\)。
所以我们只需要知道 \(0\leq i<k\) 的 \(f_{i,0}\) 即可。
但是暴力预处理是 \(O(k^2)\) 的,还是太大了。
但容易发现在 \(0\leq i <k\) 的情况下,\(f_{i,0}=2^i\)。
大概是因为在操作步数 \(<k\) 的情况下,不可能出现当前数模 \(k\) 余 \(k-1\) 向外转移的情况,所以每个数都有两种向外的转移方法。
那么就能在 \(O(k\log k\log m)\) 的时间内求出 \(\left(GA^{m+2}\right)_{0,0}\) 了。
一开始我们说过,这样会算多,因为我们是从 \(0\) 开始倒推的。
所以我们需要让倒推的第一步强制选 \(+1\)(即强制让正推的最后一步是 \(1-1\to 0\))。(注意这里保证了 \(k>1\),如果 \(k=1\) 的话请在程序开始时特判输出 \(1\))
观察 \(f_{1}\) 所对应的矩阵:
理论上来说,我们需要强制让除 \(f_{1,1}=1\) 之外的其他位置都是 \(0\)。
所以我们要减去 \(f_{1,0}=1\) 时这个 \(1\) 对之后的矩乘的贡献。
也就是说我们要减去 \(GA^m\) 第 \(0\) 列所有数的和,即 \(\left(GA^{m+1}\right)_{0,0}\)。
所以真正的答案应该是 \(\left(GA^{m+2}\right)_{0,0}-\left(GA^{m+1}\right)_{0,0}\)。
总时间复杂度 \(O(k\log k\log m)\)。
关于取模的问题:
由于 \(mod\leq 300\) 很小,所以在快速幂中:多项式乘法时用 NTT 在模一个很大的模数 \(M\)(如 \(M=998244353\))意义下进行,目的是让 NTT 中的模不起作用(即乘出来的系数不可能大于等于 \(M\),模了 \(M\) 和没模一样);多项式取模时再用 \(mod\) 模,因为此时的多项式取模可以不用 NTT,\(O(k)\) 做。
代码如下:
#include<bits/stdc++.h>
#define LN 16
#define N 10010
#define ll long long
using namespace std;
const int M=998244353;
namespace modular
{
inline int add(const int x,const int y,const int mod=M){return x+y>=mod?x+y-mod:x+y;}
inline int dec(const int x,const int y,const int mod=M){return x-y<0?x-y+mod:x-y;}
inline int mul(const int x,const int y,const int mod=M){return 1ll*x*y%mod;}
}using namespace modular;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
inline int poww(int a,int b,const int mod=M)
{
int ans=1;
while(b)
{
if(b&1) ans=mul(ans,a);
a=mul(a,a);
b>>=1;
}
return ans;
}
int k,mod;
int limit,rev[N<<2],w[LN][N<<2][2];
int a[N<<2];
ll m;
void init(int limit)
{
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
int len=mid<<1;
int gn=poww(3,(M-1)/len);
int ign=poww(gn,M-2);
int g=1,ig=1;
for(int j=0;j<mid;g=mul(g,gn),ig=mul(ig,ign),j++)
w[bit][j][0]=g,w[bit][j][1]=ig;
}
}
void NTT(int *a,int limit,int opt)
{
opt=(opt<0);
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)*(limit>>1));
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int bit=0,mid=1;mid<limit;bit++,mid<<=1)
{
for(int i=0,len=mid<<1;i<limit;i+=len)
{
for(int j=0;j<mid;j++)
{
int x=a[i+j],y=mul(w[bit][j][opt],a[i+mid+j]);
a[i+j]=add(x,y),a[i+mid+j]=dec(x,y);
}
}
}
if(opt)
{
int tmp=poww(limit,M-2);
for(int i=0;i<limit;i++)
a[i]=mul(a[i],tmp);
}
}
void modmul(int *f,int *g)
{
static int A[N<<2],B[N<<2],sum[N<<3];
for(int i=0;i<limit;i++) A[i]=f[i],B[i]=g[i];
NTT(A,limit,1),NTT(B,limit,1);
for(int i=0;i<limit;i++) A[i]=mul(A[i],B[i]);
NTT(A,limit,-1);
for(int i=0;i<limit;i++) A[i]%=mod;
for(int i=limit-1;i>=k;i--)
sum[i]=add(sum[i+1],add(A[i],dec(sum[i+1],sum[i+k+1],mod),mod),mod);
for(int i=k-1;i>=0;i--)
f[i]=add(A[i],dec(sum[k],sum[i+k+1],mod),mod);
for(int i=0;i<limit;i++) A[i]=B[i]=sum[i]=0;
}
void work(int *ans,ll b)
{
static int now[N<<2];
ans[0]=1,now[1]=1;
while(b)
{
if(b&1ll) modmul(ans,now);
modmul(now,now);
b>>=1ll;
}
}
int main()
{
scanf("%d%lld%d",&k,&m,&mod);
if(k==1)
{
puts("1");
return 0;
}
limit=1;
while(limit<((k+1)<<1)) limit<<=1;
init(limit);
work(a,m+1);
int ans1=0;
for(int i=0,tmp=1;i<k;i++)
{
ans1=add(ans1,mul(a[i],tmp,mod),mod);
if(i) tmp=add(tmp,tmp,mod);
}
for(int i=k;i>=1;i--) a[i]=a[i-1];
a[0]=0;
for(int i=0;i<k;i++) a[i]=add(a[i],a[k],mod);
int ans2=0;
for(int i=0,tmp=1;i<k;i++)
{
ans2=add(ans2,mul(a[i],tmp,mod),mod);
if(i) tmp=add(tmp,tmp,mod);
}
printf("%d\n",dec(ans2,ans1,mod));
return 0;
}
/*
2 4 31
*/