[luogu P5349] 幂 解题报告 (分治FFT)
link:
https://www.luogu.org/problemnew/show/P5349
description:
给定$m$次多项式$f(x)=\sum_{i=0}^{m}a_ix^i$以及$r$,求$\sum_{n=0}^{∞}f(n)r^n$在模$998244353$意义下的值。
solution:
设$g(x)=\sum_{n=0}^{∞}n^xr^n$,则答案为$\sum_{i=0}^{m}a_ig(i)$,因此只需求出$g(0),g(1),\dots,g(m)$
$rg(x)=\sum_{n=0}^{∞}n^xr^{n+1}=\sum_{n=1}^{∞}(n-1)^xr^n$
当$x>0$时$0^x=0$,所以$g(x)=\sum_{n=1}^{∞}n^xr^n(x>0)$(注意$x>0$这个条件:$x=0$时$0^0=1$,去掉$n=0$这一项后等式不成立)
$(1-r)g(x)=\sum_{n=1}^{∞}(n^x-(n-1)^x)r^n=r\sum_{n=0}^{∞}r^n((n+1)^x-n^x)=r\sum_{n=0}^{∞}r^n\sum_{j=0}^{x-1}\dbinom{x}{j}n^j$
$=r\sum_{j=0}^{x-1}\dbinom{x}{j}\sum_{n=0}^{∞}n^jr^n=r\sum_{j=0}^{x-1}\dbinom{x}{j}g(j)$
于是$g(x)=\frac{r}{1-r}\sum_{j=0}^{x-1}\dbinom{x}{j}g(j)$
两边同除以$x!$,并利用$\dbinom{x}{j}=\frac{x!}{j!(x-j)!}$,得到$\frac{g(x)}{x!}=\sum_{j=0}^{x-1}\frac{g(j)}{j!}\cdot\frac{r}{1-r}\cdot\frac{1}{(x-j)!}$
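记$F_x=\frac{g(x)}{x!}$,$G_i=\frac{r}{1-r}\cdot\frac{1}{i!}$($G_0=0$),则$F_x=\sum_{j=0}^{x-1}F_jG_{x-j}$,每一项只依赖于前面的项,这正是分治$FFT$(模数为$998244353$,实际用$NTT$)可以处理的形式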
值得注意的是$g(0)=\frac{1}{1-r}$,而不是$\frac{r}{1-r}$,因为在这里$0^0$的值实际上是算$1$的
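直接分治的话复杂度为$O(n\log^2 n)$;也可以记$F(z)=\sum_{x\ge0}\frac{g(x)}{x!}z^x$,$G(z)=\frac{r}{1-r}\sum_{i\ge1}\frac{z^i}{i!}$,则$F=g(0)+F\cdot G$,即$F=\frac{g(0)}{1-G}$,用多项式求逆可做到$O(n\log n)$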
code:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
using ll = long long;
constexpr int N = 400015;     // array capacity (4e5 + 15); bounds m and the NTT length
constexpr ll mo = 998244353;  // prime modulus, 119 * 2^23 + 1
int m;                        // number of weights minus one: a[0..m] are read
ll r;                         // ratio r of the series sum_{n>=0} n^i r^n
ll a[N];                      // a[i]: weight of sum_{n>=0} n^i r^n in the answer
ll wn[N];                     // wn[i] = 3^((mo-1)/2^i), filled by pre()
ll R[N];                      // bit-reversal permutation consumed by ntt()
ll fac[N];                    // fac[i] = i! mod mo
ll inv[N];                    // inv[i] = (i!)^{-1} mod mo
// Reads the next decimal integer from stdin.
// Every non-digit character before the number is skipped; a '-' seen while
// skipping makes the result negative. If EOF is reached before any digit,
// 0 is returned. ch is kept as an int so that EOF stays distinguishable from
// real characters; stored in a char, EOF never matches and the skip loop
// would spin forever on truncated input.
inline long long read()
{
    int ch=getchar();
    long long s=0,f=1;
    while (ch!=EOF&&(ch<'0'||ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') {s=(s<<3)+(s<<1)+ch-'0';ch=getchar();}
    return s*f;
}
// Modular exponentiation by repeated squaring: returns a^b mod mo.
// b must be non-negative. a is squared without prior reduction, so it is
// assumed to be small enough (roughly below 3e9) for a*a to fit in ll.
ll qpow(ll a,ll b)
{
    ll result=1;
    while (b!=0)
    {
        if (b&1) result=result*a%mo;
        b>>=1;
        a=a*a%mo;
    }
    return result;
}
// Precomputes the roots used by ntt(): wn[i] = 3^((mo-1)/2^i), with 3 the
// primitive root conventionally used with mo = 998244353 = 119*2^23 + 1.
// mo-1 is divisible by 2^23 but not by 2^24, so only i <= 23 gives an exact
// exponent (and hence a genuine 2^i-th root of unity). Entries above 23 stay
// zero: transforms longer than 2^23 are not supported by this modulus.
void pre()
{
    const int maxLog=23;   // largest k such that 2^k divides mo-1
    for (int i=0;i<=maxLog;i++) wn[i]=qpow(3,(mo-1)>>i);
}
void ntt(int limit,ll *a,int type)
{
for (int i=0;i<limit;i++) if (i<R[i]) swap(a[i],a[R[i]]);
for (int len=1,id=0;len<limit;len<<=1)
{
++id;
for (int k=0;k<limit;k+=(len<<1))
{
ll w=1;
for (int l=0;l<len;l++,w=w*wn[id]%mo)
{
ll Nx=a[k+l],Ny=w*a[k+len+l]%mo;
a[k+l]=(Nx+Ny)%mo;
a[k+len+l]=((Nx-Ny)%mo+mo)%mo;
}
}
}
if (type==1) return;
for (int i=1;i<limit/2;i++) swap(a[i],a[limit-i]);
ll inv=qpow(limit,mo-2);
for (int i=0;i<limit;i++) a[i]=a[i]%mo*inv%mo;
}
ll A[N],B[N];   // scratch buffers for the NTTs performed inside cdqfft

// Divide-and-conquer ("CDQ") online convolution modulo mo.
// After cdqfft(a,b,l,r), every x in [l,r] satisfies
//     a[x] = a_in[x] + sum_{j=l}^{x-1} a[j] * b[x-j]
// where a_in[x] is the value on entry (contributions from indices below l
// must already have been added by the caller) and a[j] on the right is the
// final value. The left half is completed before it contributes to the
// right half, which is what makes this work for self-referential recurrences.
// main() calls it on [0,m] so that f[x] = sum_{j<x} f[j]*g[x-j] for x >= 1.
// Note: the parameter r is the right end of the range; it shadows the global
// ratio r inside this function.
void cdqfft(ll *a,ll *b,int l,int r)
{
    if (l==r) return;
    int mid=(l+r)>>1;
    cdqfft(a,b,l,mid);

    // Add the contributions of a[l..mid] to a[mid+1..r] with one cyclic
    // convolution of A (a[l..mid], nonzero on [0, mid-l]) and B (b[0..len-1]).
    // We only read indices k in [mid-l+1, len-1]. A wrapped-around term comes
    // from an index k+limit >= mid-l+1+len, larger than the largest product
    // index mid-l+len-1, so a transform of length limit >= len is enough.
    const int len=r-l+1;
    int limit=1,L=0;
    while (limit<len) limit<<=1,++L;   // len >= 2, so L >= 1
    for (int i=0;i<limit;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
    for (int j=0;j<limit;j++) A[j]=0,B[j]=0;
    for (int j=l;j<=mid;j++) A[j-l]=a[j];
    for (int j=0;j<len;j++) B[j]=b[j];
    ntt(limit,A,1);ntt(limit,B,1);
    for (int i=0;i<limit;i++) A[i]=A[i]*B[i]%mo;
    ntt(limit,A,-1);
    for (int x=mid+1;x<=r;x++) a[x]=(a[x]+A[x-l])%mo;

    cdqfft(a,b,mid+1,r);
}
// f[x] holds g(x)/x!, where g(x) = sum_{n>=0} n^x r^n in the write-up.
// The array g[] is NOT that series: it is the convolution kernel r/(1-r) * 1/i!.
ll g[N],f[N];

// Reads m, r and the weights a[0..m] from stdin, then prints
//     sum_{i=0}^{m} a[i] * sum_{n>=0} n^i r^n   (mod mo).
int main()
{
    pre();
    m=read();
    // Normalise inputs into [0, mo): read() accepts a leading '-', and an
    // unreduced r would make 1-r+mo negative, while unreduced a[i] could
    // overflow the products a[i]*f[i] below.
    r=(read()%mo+mo)%mo;
    for (int i=0;i<=m;i++) a[i]=(read()%mo+mo)%mo;

    fac[0]=inv[0]=1;
    for (int i=1;i<=m;i++) fac[i]=fac[i-1]*i%mo;
    inv[m]=qpow(fac[m],mo-2);
    for (int i=m-1;i>=1;i--) inv[i]=inv[i+1]*(i+1)%mo;

    // Base case g(0) = 1/(1-r): here 0^0 counts as 1, so it is not r/(1-r).
    // NOTE(review): r == 1 (mod mo) makes the series diverge; qpow(0, ...)
    // then yields 0 and the printed answer is meaningless.
    f[0]=qpow((1-r+mo)%mo,mo-2);
    // Kernel for x >= 1: g(x)/x! = sum_{j<x} g(j)/j! * r/(1-r) * 1/(x-j)!.
    for (int i=1;i<=m;i++) g[i]=inv[i]*f[0]%mo*r%mo;
    cdqfft(f,g,0,m);

    ll ans=0;
    for (int i=0;i<=m;i++) ans=(ans+a[i]*f[i]%mo*fac[i]%mo)%mo;
    printf("%lld\n",ans);
    return 0;
}
星星之火,终将成燎原之势