[NOWCODER] myh的超级多项式

题面

已知$f_i=(\sum_{j=1}ka_j{v_j}i )\bmod 1004535809$

给定$v_1,v_2,\ldots,v_k,f_1,f_2,\ldots f_k$

求$f_n$

思路

我们考虑构造一个递推式,使得:

$f_n=\sum_{i=1}^k c_i f_{n-i}$

我们把这个$f_n$挪到右边来,令$c_0=1$,得到:

$\sum_{i=0}^k c_i f_{n-i} =0$

即:

$\sum_{i=0}^k c_i \sum_{j=1}^k a_j v_j^{n-i}=0$

这个式子的一个充分条件(可行条件)

$\forall j \in [1,k] \sum_{i=0}^k c_i a_j v_j^{n-i}=0$

把$a_j$挪到前面去,除掉一部分$v_j$的幂,得到这个式子:

$\forall j \in [1,k] \sum_{i=0}^k c_i v_j^{k-i}=0$

令$F(x)=\sum c_{k-i} x^i$,那么我们发现${v}$数组是$F(x)$的所有0点

又因为$c_0=-1$,所以$F(x)=-\prod_{i=1}^k (x-v_i)$

分治FFT求出$F(x)$,然后用$O((n-k)k)$递推(不会TLE)得到$f_n$即可

Code

代码里有一个技巧

因为一段区间得到的n+1个系数的多项式的最高次项一定是1,所以我们可以不保存他

这样分治FFT用长度为n的数组就能保存了

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MOD 1004535809
#define ll long long
using namespace std;
inline int read(){
	int re=0,flag=1;char ch=getchar();
	while(!isdigit(ch)){
		if(ch=='-') flag=-1;
		ch=getchar();
	}
	while(isdigit(ch)) re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
	return re*flag;
}
ll qpow(ll a,ll b){
	ll re=1;
	while(b){
		if(b&1) re=re*a%MOD;
		a=a*a%MOD;b>>=1;
	}
	return re;
}
ll add(ll a,ll b){
	a+=b;
	return ((a>=MOD)?a-MOD:a);
}
ll dec(ll a,ll b){
	a-=b;
	return ((a<0)?a+MOD:a);
}
ll g=3,ginv;
namespace NTT{
	int lim,cnt,r[400010];
	ll A[400010],B[400010];
	void ntt(ll *a,ll type){
		int i,j,k,mid;ll x,y,w,wn,inv;
		for(i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
		for(mid=1;mid<lim;mid<<=1){
			wn=qpow(((~type)?g:ginv),(MOD-1)/(mid<<1));
			for(j=0;j<lim;j+=(mid<<1)){
				w=1;
				for(k=0;k<mid;k++,w=w*wn%MOD){
					x=a[j+k];y=a[j+k+mid]*w%MOD;
					a[j+k]=add(x,y);
					a[j+k+mid]=dec(x,y);
				}
			}
		}
		if(~type) return;
		inv=qpow(lim,MOD-2);
		for(i=0;i<lim;i++) a[i]=a[i]*inv%MOD;
	}
	void init(int n){
		int i;
		lim=1;cnt=0;
		while(lim<=n) lim<<=1,cnt++;
		for(i=0;i<lim;i++) r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1))),A[i]=B[i]=0;
	}
}
void mul(){
	using namespace NTT;
	ntt(A,1);ntt(B,1);int i;
	for(i=0;i<lim;i++) A[i]=A[i]*B[i]%MOD;
	ntt(A,-1);
}
ll c[100010];//黑科技数组
int n,k;ll v[100010],f[100010];
void solve(int l,int r){
	if(l==r){
		c[l]=MOD-v[l];
		return;
	}
	int mid=(l+r)>>1,i;
	solve(l,mid);solve(mid+1,r);
	using namespace NTT;
	init(r-l+1);
	for(i=0;i<=mid-l;i++) A[i]=c[i+l];
	for(i=0;i<r-mid;i++) B[i]=c[i+mid+1];
	A[mid-l+1]=B[r-mid]=1;//把没记录的1加上
	mul();
	for(i=0;i<=r-l;i++) c[l+i]=A[i];//这里不保存1
}
int main(){
	n=read();k=read();int i,j;
	g=3;ginv=qpow(3,MOD-2);
	for(i=1;i<=k;i++) v[i]=read();
	for(i=1;i<=k;i++) f[i]=read();
	solve(1,k);
	for(i=0;i<k;i++) c[i]=c[i+1];
	c[k]=1;
	for(i=0;i<=k;i++) if(c[i]) c[i]=MOD-c[i];
	for(i=0;i<=k/2;i++) swap(c[i],c[k-i]);
	for(i=k+1;i<=n;i++){
		ll w=0;
		for(j=1;j<=k;j++) w+=c[j]*f[i-j]%MOD;
		f[i]=w%MOD;
	}
	printf("%lld\n",f[n]);
}
posted @ 2018-09-30 11:52  dedicatus545  阅读(333)  评论(0编辑  收藏  举报