[模板] 常系数齐次线性递推

一、题目

点此看题

注意我的写的 \(a\)\(f\) 和题目里面的是反的。

二、解法

我看 \(\tt oiwiki\) 上面的讲解就秒懂了!真的讲得特别特别好!

\(F(\sum c_ix^i)=\sum c_if_i\)\(F(x^n)\) 就是答案。

也就是我们用生成函数第 \(i\) 项作为 \(f_i\) 的记号

由于 \(f_n=\sum_{i=1}^k f_{n-i}a_i\),所以 \(F(x^n)=F(\sum_{i=1}^k a_ix^{n-i})\)

不难发现函数里面也可以直接减的,所以:

\[F(x^n-\sum_{i=1}^ka_ix^{n-i})=F(x^{n-k}(x^k-\sum_{i=0}^{k-1}a_{k-i}x^i)) \]

\(G(x)=x^k-\sum_{i=0}^{k-1}a_{k-i}x^i\),那么就有 \(F(A(x)+x^nG(x))=F(A(x))+F(x^nG(x))=F(A(x))\)

也就是说如果算 \(F(x^n)\) 的话就可以直接取模 \(G(x)\),设 \(P(x)=x^n\bmod G(x)\),那么答案就是 \(F(P(x))\)\(P(x)\) 是一个 \(k-1\) 次多项式所以可以直接根据定义算,用一个快速幂套多项式取模即可,时间复杂度 \(O(k\log n\log k)\)

写出来跑到了 \(\tt luogu\) 倒数第一,很好。

#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define int long long
const int M = 400005;
const int MOD = 998244353;
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,k,ans,r[M],a[M],b[M],f[M],g[M];
//b表示x^n模G(x)之后的多项式 
namespace poly//封装了 
{
	int len,A[M],B[M],c[M],d[M],e[M],rev[M];
	int qkpow(int a,int b)
	{
		int r=1;
		while(b>0)
		{
			if(b&1) r=r*a%MOD;
			a=a*a%MOD;
			b>>=1;
		}
		return r;
	}
	void NTT(int *a,int len,int op)
	{
		for(int i=0;i<len;i++)
		{
			rev[i]=(rev[i>>1]>>1)|((i&1)*(len/2));
			if(i<rev[i]) swap(a[i],a[rev[i]]);
		}
		for(int s=2;s<=len;s<<=1)
		{
			int t=s/2,w=(op==1)?qkpow(3,(MOD-1)/s):qkpow(3,MOD-1-(MOD-1)/s);
			for(int i=0;i<len;i+=s)
				for(int j=0,x=1;j<t;j++,x=x*w%MOD)
				{
					int fe=a[i+j],fo=a[i+j+t];
					a[i+j]=(fe+x*fo)%MOD;
					a[i+j+t]=((fe-x*fo)%MOD+MOD)%MOD;
				}
		}
		if(op==1) return ;
		int inv=qkpow(len,MOD-2);
		for(int i=0;i<len;i++) a[i]=a[i]*inv%MOD;
	}
	void work(int n,int *a,int *b)//逆元从属函数 
	{
		len=1;while(len<2*n) len<<=1;
		for(int i=0;i<len;i++) A[i]=B[i]=0;
		for(int i=0;i<n;i++) A[i]=a[i];
		for(int i=0;i<(n/2);i++) B[i]=b[i];
		NTT(A,len,1);NTT(B,len,1);
		for(int i=0;i<len;i++)
			A[i]=((2*B[i]-B[i]*B[i]%MOD*A[i])%MOD+MOD)%MOD;
		NTT(A,len,-1);
		for(int i=0;i<n;i++) b[i]=A[i];
	}
	void inv(int n,int *a,int *b)//逆元存在b那里 
	{
		b[0]=qkpow(a[0],MOD-2);
		int cur=1;
		while(cur<n)
		{
			cur<<=1;
			work(cur,a,b);
		}
	}
	void mul(int n,int *a,int *b)//多项式乘法 
	{
		len=1;while(len<2*n) len<<=1;
		for(int i=0;i<len;i++) A[i]=B[i]=0;
		for(int i=0;i<n;i++) A[i]=a[i],B[i]=b[i];
		NTT(A,len,1);NTT(B,len,1);
		for(int i=0;i<len;i++) A[i]=A[i]*B[i]%MOD;
		NTT(A,len,-1);
		for(int i=0;i<2*n;i++) b[i]=A[i]; 
	}
	void mod(int n,int m,int *a,int *b)
	//n次多项式a取模m次多项式b
	//最后的结果是余数,存在a处
	{
		//翻转A 
		for(int i=0;i<=n;i++) d[i]=a[i];
		for(int i=0;i<=n/2;i++) swap(d[i],d[n-i]);
		//翻转B 
		for(int i=0;i<=m;i++) e[i]=b[i];
		for(int i=0;i<=m/2;i++) swap(e[i],e[m-i]);
		inv(n-m+1,e,c);
		//清除掉无用的部分 
		len=1;while(len<=2*(n-m)) len<<=1;
		for(int i=n-m+1;i<len;i++) d[i]=c[i]=0;
		NTT(c,len,1);NTT(d,len,1);
		for(int i=0;i<len;i++) c[i]=c[i]*d[i]%MOD;
		NTT(c,len,-1);
		for(int i=n-m+1;i<len;i++) c[i]=0;
		for(int i=0;i<=(n-m)/2;i++) swap(c[i],c[n-m-i]);
		//算余数
		len=1;while(len<=n) len<<=1; 
		for(int i=0;i<=m/2;i++) swap(e[i],e[m-i]);//翻转回来 
		NTT(c,len,1);NTT(e,len,1);
		for(int i=0;i<len;i++) c[i]=c[i]*e[i]%MOD;
		NTT(c,len,-1);
		for(int i=0;i<=n;i++)
			a[i]=((a[i]-c[i])+MOD)%MOD;
		for(int i=0;i<len;i++) e[i]=c[i]=0;//用了就清 
	}
};
signed main()
{
	n=read();k=read();
	for(int i=1;i<=k;i++)
		b[i]=read();
	for(int i=0;i<k;i++)
		g[i]=(MOD-b[k-i])%MOD;
	g[k]=1;
	for(int i=0;i<k;i++)
		f[i]=read();
	a[1]=1;//初始值是x
	r[0]=1;//初始值是1 
	while(n>0)
	{
		if(n&1)
		{
			poly::mul(k,a,r);
			poly::mod(2*k-2,k,r,g);
		}
		poly::mul(k,a,a);
		poly::mod(2*k-2,k,a,g);
		n>>=1;
	}
	for(int i=0;i<k;i++)
		ans=(ans+f[i]*r[i])%MOD;
	printf("%lld\n",(ans+MOD)%MOD);
}
posted @ 2021-03-18 19:44  C202044zxy  阅读(110)  评论(0编辑  收藏  举报