星星之火

[luogu P5349] 幂 解题报告 (分治FFT)

interlinkage:

https://www.luogu.org/problemnew/show/P5349

description:

solution:

设$g(x)=\sum_{n=0}^{∞}n^xr^n$

$rg(x)=\sum_{n=0}^{∞}n^xr^{n+1}=\sum_{n=1}^{∞}(n-1)^xr^n$

$g(x)=\sum_{n=1}^{∞}n^xr^n(x>0)$(注意$x>0$这个条件,$x=0$的时候这个不符合)

$(1-r)g(x)=\sum_{n=1}^{∞}(n^x-(n-1)^x)r^n=r\sum_{n=0}^{∞}r^n((n+1)^x-n^x)=r\sum_{n=0}^{∞}r^n\sum_{j=0}^{x-1}\dbinom{x}{j}n^j$

$=r\sum_{j=0}^{x-1}\dbinom{x}{j}\sum_{n=0}^{∞}n^jr^n=r\sum_{j=0}^{x-1}\dbinom{x}{j}g(j)$

于是$g(x)=\frac{r}{1-r}\sum_{j=0}^{x-1}\dbinom{x}{j}g(j)$

继续化简得到$\frac{g(x)}{x!}=\sum_{j=0}^{x-1}\frac{g(j)}{j!}(\frac{r}{1-r}*\frac{1}{(x-j)!})$

这个显然可以用分治$FFT$来做

值得注意的是$g(0)=\frac{1}{1-r}$,而不是$\frac{r}{1-r}$,因为在这里$0^0$的值实际上是算$1$的

直接分治的话复杂度为$O(nlognlogn)$,多项式求逆时间复杂度为$O(nlogn)$

code:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;

const int N=4e5+15;
const ll mo=998244353;
int m;
ll r;
ll a[N],wn[N],R[N],fac[N],inv[N];
inline ll read()
{
    char ch=getchar();ll s=0,f=1;
    while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') {s=(s<<3)+(s<<1)+ch-'0';ch=getchar();}
    return s*f;
}
ll qpow(ll a,ll b)
{
    ll re=1;
    for (;b;b>>=1,a=a*a%mo) if (b&1) re=re*a%mo;
    return re;
}
void pre()
{
    for (int i=0;i<=25;i++)
    {
        ll t=1ll<<i;
        wn[i]=qpow(3,(mo-1)/t);
    }
}
void ntt(int limit,ll *a,int type)
{
    for (int i=0;i<limit;i++) if (i<R[i]) swap(a[i],a[R[i]]);
    for (int len=1,id=0;len<limit;len<<=1)
    {
        ++id;
        for (int k=0;k<limit;k+=(len<<1))
        {
            ll w=1;
            for (int l=0;l<len;l++,w=w*wn[id]%mo)
            {
                ll Nx=a[k+l],Ny=w*a[k+len+l]%mo;
                a[k+l]=(Nx+Ny)%mo;
                a[k+len+l]=((Nx-Ny)%mo+mo)%mo;
            }
        }
    }
    if (type==1) return;
    for (int i=1;i<limit/2;i++) swap(a[i],a[limit-i]);
    ll inv=qpow(limit,mo-2);
    for (int i=0;i<limit;i++) a[i]=a[i]%mo*inv%mo;
}
ll A[N],B[N];
void cdqfft(ll *a,ll *b,int l,int r)
{
    if (l==r) return;
    int mid=l+r>>1;
    cdqfft(a,b,l,mid);
    
    int limit=1,L=0;
    while (limit<=(r-l+1)*2) limit<<=1,++L;
    for (int i=0;i<=limit;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
    
    for (int j=0;j<=limit;j++) A[j]=0,B[j]=0;
    for (int j=l;j<=mid;j++) A[j-l]=a[j];
    for (int j=0;j<=r-l;j++) B[j]=b[j];
    ntt(limit,A,1);ntt(limit,B,1);
    for (int i=0;i<=limit;i++) A[i]=A[i]*B[i]%mo;
    ntt(limit,A,-1);
    for (int x=mid+1;x<=r;x++) a[x]=(a[x]+A[x-l])%mo;
    cdqfft(a,b,mid+1,r);
}
ll g[N],f[N];
int main()
{
    pre();
    m=read();r=read();
    for (int i=0;i<=m;i++) a[i]=read();
    fac[0]=inv[0]=1;
    for (int i=1;i<=m;i++) fac[i]=fac[i-1]*i%mo;
    inv[m]=qpow(fac[m],mo-2);
    for (int i=m-1;i>=1;i--) inv[i]=inv[i+1]*(i+1)%mo;
    f[0]=qpow(1-r+mo,mo-2)%mo;
    for (int i=1;i<=m;i++) g[i]=inv[i]*f[0]%mo*r%mo;
    cdqfft(f,g,0,m); 
    ll ans=0;    
    for (int i=0;i<=m;i++) ans=(ans+a[i]*f[i]%mo*fac[i]%mo)%mo;
    printf("%lld\n",ans);
    return 0;
}
posted @ 2019-05-06 22:31  星星之火OIer  阅读(267)  评论(0编辑  收藏  举报