洛谷 4723 【模板】线性递推——常系数线性齐次递推
题目:https://www.luogu.org/problemnew/show/P4723
题解:https://www.luogu.org/problemnew/solution/P4723
特征多项式:\( f(x) = x^k - \sum\limits_{i=1}^{k}f_i x^{n-i} \)
这个多项式是转移矩阵 M 的化零多项式,所以 \( M^n \) 可以对该多项式取模,从而化成 \( M^n = \sum\limits_{i=0}^{k-1} g_i M^i \)
如果知道 \( g_i \) ,考虑给上式两边乘上初始的系数矩阵,得到 \( A(x)*M^n = \sum\limits_{i=0}^{k-1}g_i A(x) M^i \)
左边关注的只有该行向量的第一个位置的值。把所有行向量 “\(A(x)*M^i\)” 都改成其第一个位置的值(据说也是成立的?),发现左边是 a[n] 、右边是 a[i] 。
所以 \( a[n] = \sum\limits_{i=0}^{k-1} g_i a[i] \)
考虑求出 \( g_i \) 。
其实就是多项式 \( x^n \) 对多项式 \( x^k - \sum\limits_{i=1}^{k}f_ix^{n-i} \) 取模得到的多项式。
用多项式 \( x \) 快速幂乘出 \( x^n \) 。一边乘一边对该多项式取模即可。
关于多项式取模的注意事项:
\( A(x) = G(x)B(x)+R(x) \),其中 \(A(x)\)是 n 次,\(G(x)\)是 m 次的模数,\(B(x)\)是 n-m 次的商,\(R(x)\) 是 m-1 次的余数。
\( A^R(x) = G^R(x)B^R(x) + x^{n-m+1}R^R(x) \)
\( A^R(x) = G^R(x)B^R(x) ( mod x^{n-m+1} ) \)
\( B^R(x) = \frac{ A^R(x) }{ G^R(x) } ( mod x^{n-m+1} ) \)
注意这一步!!!
1. \(A^R(x)\) 是先翻转之后再对 \(x^{n-m+1}\) 取模。当然不对之取模也可。
2.可以预处理 \(G^R(x)\) 的逆元,求逆就是在 mod \(x^{n-m+1}\) 意义下的;但注意在求逆之前,原始的数组不要对 \( x^{n-m+1} \) 取模!
3.得出的 \(B^R(x)\) 需要取一下模。注意是先取了模,在翻转回去得到 \(B(x)\)
4.各种时刻要把临时数组使劲清空!!!!!!一直清空到 len 的程度。如果 len 变化了,还要一直清空到新的 len 的范围为止!!!
得到 \(B(x)\) 之后 \( R(x) = A(x) - G(x)*B(x) \)
注意这里没有取模。但是得出的 \( R(x) \) 应该是 m-1 次多项式。需要算好之后手动把 m 次项及以后的系数清零。
注意如果没有把 \( A(x) \) 化成点值,就不要写成 for( i=0;i<len;i++ ) ta[ i ] = a[ i ] - ta[ i ]*g[ i ] ; !!!
可以预处理 \( G(x) \) 的逆元。反正 len 只有两种。
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=(1<<17)+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n,n2,m,f[N],h[N],g[N],ig[N],tp[N]; int len,r[N],wn[N],wn2[N],inv[N]; int ta[N],ret[N]; void Rev(int *a,int n) { int k=(n+1)>>1; for(int i=0;i<k;i++)swap(a[i],a[n-i]);} void ntt_pre() { int lm=(1<<17); for(int R=2;R<=lm;R<<=1) { wn[R]=pw(3,(mod-1)/R); wn2[R]=pw(3,(mod-1)-(mod-1)/R); inv[R]=pw(R,mod-2); } } void ntt_r() { for(int i=0,j=len>>1;i<len;i++) r[i]=(r[i>>1]>>1)+((i&1)?j:0); } void ntt(int *a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int Wn=(fx?wn2[R]:wn[R]); for(int i=0,m=R>>1;i<len;i+=R) for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%mod) { int x=a[i+j],y=(ll)w*a[i+m+j]%mod; a[i+j]=upt(x+y); a[i+m+j]=upt(x-y); } } if(!fx)return; int iv=inv[len]; for(int i=0;i<len;i++)a[i]=(ll)a[i]*iv%mod; } void get_inv(int *a,int lm) { memset(ret,0,sizeof ret); ret[0]=pw(a[0],mod-2); for(int t=2,yt=1,i,j;yt<lm;yt=t,t=len) { len=t<<1; ntt_r(); for(i=0;i<t;i++)ta[i]=a[i];for(;i<len;i++)ta[i]=0; ntt(ret,0); ntt(ta,0); for(i=0;i<len;i++) ret[i]=(ll)ret[i]*upt(2-(ll)ret[i]*ta[i]%mod)%mod; ntt(ret,1); for(i=t;i<len;i++)ret[i]=0; } for(int i=lm;i<len;i++)ret[i]=0; memcpy(a,ret,sizeof ret); } void Mul(int *a,int *b)//(m-1) { for(len=1;len<=n2;len<<=1); ntt_r(); memcpy(ta,b,sizeof b); ntt(a,0); ntt(ta,0); for(int i=0;i<len;i++)a[i]=(ll)a[i]*ta[i]%mod; ntt(a,1); } void get_mod(int *a)//n2 % m { int d=n2-m+1; for(len=1;len<d<<1;len<<=1); ntt_r(); memcpy(ta,a,sizeof 4*(n2+1)); Rev(ta,n2); for(int i=d;i<=len;i++)ta[i]=0; //i<=len not n2//rev before mod //not mod is ok,but clear (n2,len)!!! ntt(ta,0); for(int i=0;i<len;i++)ta[i]=(ll)ta[i]*ig[i]%mod; ntt(ta,1); for(int i=d;i<len;i++)ta[i]=0; Rev(ta,d-1);//mod before rev for(len=1;len<=n2;len<<=1); ntt_r(); for(int i=d;i<len;i++)ta[i]=0;//////new len ntt(ta,0); for(int i=0;i<len;i++) ta[i]=(ll)ta[i]*g[i]%mod; //ta[i]=upt((a[i]-(ll)ta[i]*g[i])%mod); ntt(ta,1); for(int i=0;i<m;i++)ta[i]=upt(a[i]-ta[i]); for(int i=m;i<len;i++)ta[i]=0;// memcpy(a,ta,sizeof ta); } int main() { n=rdn();m=rdn(); n2=2*(m-1); ntt_pre(); for(int i=1;i<=m;i++)f[i]=rdn(); for(int i=0;i<m;i++)h[i]=upt(rdn()); for(int i=1;i<=m;i++)g[m-i]=ig[m-i]=upt(-f[i]); g[m]=ig[m]=1; Rev(ig,m); int d=n2-m+1; get_inv(ig,d); for(len=1;len<d<<1;len<<=1); ntt_r(); ntt(ig,0);//d<<1 for(len=1;len<=n2;len<<=1); ntt_r(); ntt(g,0); memset(f,0,sizeof f); f[0]=tp[1]=1; while(n) { if(n&1){ Mul(f,tp);get_mod(f);} for(len=1;len<=n2;len<<=1); ntt_r(); ntt(tp,0); for(int i=0;i<len;i++)tp[i]=(ll)tp[i]*tp[i]%mod; ntt(tp,1); get_mod(tp); n>>=1; } int ans=0; for(int i=0;i<m;i++) ans=(ans+(ll)f[i]*h[i])%mod; printf("%d\n",ans); return 0; }