hdu 6900 Residual Polynomial (NTT)
题目链接
http://acm.hdu.edu.cn/showproblem.php?pid=6900
题意
定义\(f_1(x)=\sum_{i=0}^{n}a_ix^i\)，给定序列\(a_i,b_i,c_i\)以及递推式\(f_i(x)=b_i\,f_{i-1}'(x)+c_i\,f_{i-1}(x)\ (2\le i\le n)\)，求\(f_n(x)\)的各项系数（对\(998244353\)取模）。
思路
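一篇讲得很好的博客：https://www.cnblogs.com/JustinRochester/p/13705300.html

简述：令\(e_i=a_i\cdot i!\)，则对多项式求导相当于把\(e\)的下标整体左移一位（\((f')\)的第\(i\)项乘以\(i!\)恰为\(e_{i+1}\)）。因此\(n-1\)次变换等价于把\(e\)与\(\prod_{i=2}^{n}(b_i+c_ix)\)做一次卷积，结果的第\(n-1+k\)项就是答案第\(k\)项系数乘以\(k!\)。该乘积用分治 + NTT 求出，最后再乘以\((k!)^{-1}\)还原系数。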
#include<bits/stdc++.h>
using namespace std;
// 64-bit working type for residues modulo `mod`.
typedef long long LL;
// Capacity of every per-index array below.
const int maxx = 1e6+10;
// NTT-friendly prime, the root G used by the forward transform, and
// G1 = G^{-1} mod p (3 * 332748118 == 998244354 == 1 mod p), used by the inverse transform.
const int mod = 998244353;
const int G = 3;
const int G1 = 332748118;
// Input sequences a_i, b_i, c_i, and the bit-reversal permutation table used by NTT().
LL a[maxx], b[maxx], c[maxx], r[maxx];
// Scratch buffers that mul() transforms in place.
LL f[maxx], g[maxx];
// d[rt]: heap-allocated coefficient array of node rt in solve()'s recursion tree.
LL *d[maxx<<2];
// e[i] = a_i * i! (mod p); ans receives the final convolution of d[1] and e.
LL e[maxx], ans[maxx];
// Factorials p[i] = i! and their modular inverses invp[i].
LL p[maxx], invp[maxx];
// Current transform length, set by mul() and consumed by NTT().
int limit;
// Modular exponentiation: returns a^b mod `mod` by square-and-multiply.
// b is expected to be non-negative.
LL quick(LL a,LL b)
{
    LL result=1;
    for(LL base=a;b!=0;b>>=1)
    {
        // Fold in the current power of `base` when this bit of the exponent is set.
        if(b&1)
            result=result*base%mod;
        base=base*base%mod;
    }
    return result;
}
// Precompute factorials p[i] = i! and inverse factorials invp[i] = (i!)^{-1}
// modulo `mod` for every 0 <= i < maxx.
void init()
{
    p[0]=1;
    for(int i=1;i<maxx;i++)p[i]=p[i-1]*i%mod;
    // One exponentiation for the largest factorial, then walk down using
    // ((i-1)!)^{-1} = (i!)^{-1} * i. This also sets invp[0] = 1, which the
    // per-index version never assigned.
    invp[maxx-1]=quick(p[maxx-1],mod-2);
    for(int i=maxx-1;i>0;i--)invp[i-1]=invp[i]*i%mod;
}
// In-place iterative NTT of A[0..limit-1] modulo `mod`.
// type == 1 runs the forward transform (root G); any other value uses the
// inverse root G1. The inverse result is NOT divided by `limit` here; mul()
// performs that normalization.
// Relies on the global bit-reversal table r[] and the length `limit` set by mul().
void NTT(LL *A,int type)
{
    // Permute the input into bit-reversed index order.
    for(int i=0;i<limit;i++)
    {
        if(r[i]>i)
            swap(A[i],A[r[i]]);
    }
    const LL root=(type==1)?G:G1;
    // Each pass merges blocks of size `half` into blocks of size 2*half.
    for(int half=1;half<limit;half*=2)
    {
        const int len=half*2;
        const LL step=quick(root,(mod-1)/len);   // len-th root of unity (or its inverse)
        for(int start=0;start<limit;start+=len)
        {
            LL w=1;
            for(int k=0;k<half;k++)
            {
                int u=A[start+k];
                int v=w*A[start+k+half]%mod;
                A[start+k]=(u+v)%mod;
                A[start+k+half]=(u-v+mod)%mod;
                w=w*step%mod;
            }
        }
    }
}
// Multiply polynomial a (coefficients a[0..n]) by polynomial b (coefficients
// b[0..m]) modulo `mod` with NTT, storing the product's coefficients in h[0..n+m].
// h must have room for at least `limit` entries (the smallest power of two > n+m),
// because the inverse transform runs in place over h[0..limit-1].
// Side effects: overwrites the global scratch buffers f, g, the bit-reversal
// table r, and the global `limit` consumed by NTT().
void mul(LL *a,LL *b,LL *h,int n,int m)
{
    int L=0;
    limit=1;
    while(limit<=n+m)limit<<=1,L++;
    // r[i] = bit-reversal of i over L bits. r[0] is always 0. Starting at i=1
    // also avoids the undefined shift by L-1 == -1 when limit == 1 (n == m == 0).
    r[0]=0;
    for(int i=1;i<limit;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(L-1));
    for(int i=0;i<limit;i++)f[i]=g[i]=0;
    for(int i=0;i<=n;i++)f[i]=a[i];
    for(int i=0;i<=m;i++)g[i]=b[i];
    NTT(f,1),NTT(g,1);
    // Pointwise product in the evaluation domain.
    for(int i=0;i<limit;i++)h[i]=(f[i]*g[i])%mod;
    NTT(h,-1);
    // The inverse transform leaves every coefficient scaled by `limit`.
    LL inv=quick(limit,mod-2);
    for(int i=0;i<=n+m;i++)h[i]=h[i]*inv%mod;
}
void solve(int l,int r,int rt)
{
if(l==r)
{
d[rt]=new LL[2];
d[rt][0]=b[l];
d[rt][1]=c[l];
return;
}
int mid=(l+r)/2;
solve(l,mid,rt*2);
solve(mid+1,r,rt*2+1);
d[rt]=new LL[2*(r-l+1)];
mul(d[rt*2],d[rt*2+1],d[rt],mid-l+1,r-mid);
}
void del(int l,int r,int rt)
{
delete d[rt];
if(l==r)return;
int mid=(l+r)/2;
del(l,mid,rt*2);
del(mid+1,r,rt*2+1);
}
// Reads T test cases. For each one it reads n, a_0..a_n, b_2..b_n and c_2..c_n,
// then prints the n+1 coefficients of f_n mod 998244353 on one line.
int main()
{
    init();
    int T;
    scanf("%d",&T);
    while(T--)
    {
        int n;
        scanf("%d",&n);
        // e[i] = a_i * i!. In this representation the derivative becomes the
        // index shift e_i -> e_{i+1}.
        for(int i=0;i<=n;i++)scanf("%lld",&a[i]),e[i]=a[i]*p[i]%mod;
        for(int i=2;i<=n;i++)scanf("%lld",&b[i]);
        for(int i=2;i<=n;i++)scanf("%lld",&c[i]);
        if(n<2)
        {
            // No recurrence steps apply: f_1 is the answer. (solve(2, n, 1)
            // with n < 2 would recurse forever.)
            for(int i=0;i<=n;i++)printf(i?" %lld":"%lld",a[i]%mod);
            printf("\n");
            continue;
        }
        solve(2,n,1);
        // d[1] = prod_{i=2..n}(b_i + c_i x) has degree n-1 and only n-1 is valid:
        // d[1][n] is out of bounds for n == 2 and never written when n is a power of two.
        mul(d[1],e,ans,n-1,n);
        // ans[n-1+k] = (k-th coefficient of f_n) * k!.
        printf("%lld",ans[n-1]);
        for(int i=1;i<=n;i++)printf(" %lld",ans[i+n-1]*invp[i]%mod);
        printf("\n");
        del(2,n,1);
    }
    return 0;
}