洛谷P5349 幂

题目描述

给定 $m$ 次多项式 $f(n)=\sum_{i=0}^{m}a_in^i$ 和 $r$，求 $\sum_{n=0}^{\infty}f(n)r^n \bmod 998244353$。

题解

化简式子：$$ans=\sum_{n=0}^{\infty}f(n)r^n=\sum_{n=0}^{\infty}r^n\sum_{i=0}^{m}a_in^i=\sum_{i=0}^{m}a_i\sum_{n=0}^{\infty}r^nn^i$$
设 $f_i(r)=\sum_{n=0}^{\infty}r^nn^i$（约定 $0^0=1$），则 $rf_i(r)=\sum_{n=0}^{\infty}r^{n+1}n^i=\sum_{n=1}^{\infty}r^n(n-1)^i$。
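所以当 $i\ge 1$ 时（此时 $n=0$ 的项为 $0^i=0$），$$(1-r)f_i(r)=\sum_{n=1}^{\infty}r^n\left(n^i-(n-1)^i\right)$$$$=r\sum_{n=0}^{\infty}r^n\left((n+1)^i-n^i\right)=r\sum_{n=0}^{\infty}r^n\sum_{j=0}^{i-1}\binom{i}{j}n^j$$$$=r\sum_{j=0}^{i-1}\binom{i}{j}\sum_{n=0}^{\infty}r^nn^j=r\sum_{j=0}^{i-1}\binom{i}{j}f_j(r)$$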
所以当 $i\ge 1$ 时 $$f_i(r)=\frac{r}{1-r}\sum_{j=0}^{i-1}\binom{i}{j}f_j(r)$$两边同除以 $i!$，得$$\frac{f_i(r)}{i!}=\frac{r}{1-r}\sum_{j=0}^{i-1}\frac{f_j(r)}{j!}\cdot\frac{1}{(i-j)!}$$初值为 $f_0(r)=\sum_{n=0}^{\infty}r^n=\frac{1}{1-r}$。
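于是令 $g_i=\frac{f_i(r)}{i!}$，上式右边是 $g$ 与 $\frac{1}{k!}$ 的卷积，可以用分治 NTT 依次求出 $g_0,\dots,g_m$，答案为 $\sum_{i=0}^{m}a_i\,i!\,g_i$。时间复杂度：$O(m\log^2 m)$。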

代码

#include <bits/stdc++.h>
using namespace std;
// Array capacity and the NTT-friendly prime modulus.
const int N=800005,P=998244353;
// Input: degree m, ratio r (read into n by main), coefficients a_0..a_m; s accumulates the answer.
int a[N],s,m,n;
// V = 1/(1-r) and U = r/(1-r) modulo P (computed in main from n).
int V,U;
// Current transform length t = 2^p and its bit-reversal permutation re[0..t-1] (filled by pre).
int t,p,re[N];
// Scratch buffers for one convolution; they are reset to zero after each use in solve.
int A[N],B[N];
// f[i] = f_i(r)/i!, ny[i] = 1/i!, jc[i] = i! (all modulo P).
int f[N],ny[N],jc[N];
// G[0] = 3 is used for forward transforms; G[1] = (P+1)/3 is its inverse modulo P, used for inverse transforms.
int G[2]={3,(P+1)/3};
void pre(int l){
    for (t=1,p=0;t<l;t<<=1,p++);
    for (int i=0;i<t;i++)
        re[i]=(re[i>>1]>>1)|((i&1)<<(p-1));
}
// Reduce x from [0, 2P) into [0, P) with a single conditional subtraction of P.
int X(int x){
    if (x<P) return x;
    return x-P;
}
/*
 * Binary exponentiation: returns x^y mod P for y >= 0.
 * Used both for ordinary powers and for modular inverses via x^(P-2).
 */
int K(int x,int y){
    long long base=x;
    long long acc=1;
    while (y!=0){
        if (y&1) acc=acc*base%P;
        base=base*base%P;
        y>>=1;
    }
    return static_cast<int>(acc);
}
/*
 * In-place iterative number-theoretic transform of a[0..t-1] modulo P.
 * Requires pre() to have been called so that t and re[] describe the length.
 * o == 0: forward transform using root G[0];
 * o != 0: inverse transform using G[1], followed by scaling with 1/t.
 */
void Ntt(int *a,int o){
    // Put the input into bit-reversed index order.
    for (int i=0;i<t;++i)
        if (i<re[i]) std::swap(a[i],a[re[i]]);
    // Butterfly passes; half is the half-length of the sub-blocks merged in this pass.
    for (int half=1;half<t;half<<=1){
        const int step=K(G[o],(P-1)/(half<<1));
        for (int start=0;start<t;start+=(half<<1)){
            int w=1;
            for (int k=0;k<half;++k){
                int &lo=a[start+k];
                int &hi=a[start+k+half];
                const int u=lo;
                const int v=static_cast<int>(1ll*w*hi%P);
                lo=X(u+v);
                hi=X(u-v+P);
                w=static_cast<int>(1ll*w*step%P);
            }
        }
    }
    if (o){
        // Inverse transform: divide every entry by t.
        const int invLen=K(t,P-2);
        for (int i=0;i<t;++i)
            a[i]=static_cast<int>(1ll*a[i]*invLen%P);
    }
}
/*
 * CDQ divide-and-conquer (online convolution) over f[l..r].
 * Computes f[i] = f_i(r)/i! from
 *     f[i] = U * sum_{j<i} f[j] * ny[i-j],   f[0] = V.
 * The left half is finished first. Its contribution to the right half is then
 * added through one NTT convolution, and finally the right half is solved.
 * Expects the scratch buffers A and B to be all zero on entry, and leaves them zero.
 */
void solve(int l,int r){
    if (l==r){
        // Leaf: only f[0] needs seeding; any other index is complete by now.
        if (l==0) f[l]=V;
        return;
    }
    const int mid=(l+r)/2;
    solve(l,mid);

    // A holds the finished f[l..mid]; B holds 1/k! for k = 0..r-l+1.
    for (int i=l;i<=mid;++i) A[i-l]=f[i];
    for (int k=0;k<=r-l+1;++k) B[k]=ny[k];

    // The chosen length exceeds the linear-convolution size, so there is no cyclic wrap-around.
    pre(mid-l+r-l+3);
    Ntt(A,0);
    Ntt(B,0);
    for (int i=0;i<t;++i) A[i]=1ll*A[i]*B[i]%P;
    Ntt(A,1);

    // A[i-l] = sum_{j=l..mid} f[j] * ny[i-j]; scale it by U and add it to f[i].
    for (int i=mid+1;i<=r;++i) f[i]=X(f[i]+1ll*A[i-l]*U%P);

    // Reset the scratch buffers for the next call.
    for (int i=0;i<t;++i){
        A[i]=0;
        B[i]=0;
    }
    solve(mid+1,r);
}
/*
 * Reads m, r (into n) and the coefficients a_0..a_m, then prints
 * sum_{i=0}^{m} a_i * f_i(r) mod P, where f_i(r) = sum_{k>=0} r^k k^i.
 */
int main(){
    cin>>m>>n;
    // V = 1/(1-r) = f_0(r);  U = r/(1-r) is the common factor of the recurrence.
    V=K(X(1-n+P),P-2);
    U=1ll*n*V%P;
    for (int i=0;i<=m;++i) scanf("%d",&a[i]);

    // Factorials jc[i] = i! and inverse factorials ny[i] = 1/i! for i = 0..m.
    jc[0]=1;
    for (int i=1;i<=m;++i) jc[i]=1ll*jc[i-1]*i%P;
    ny[m]=K(jc[m],P-2);
    for (int i=m;i>0;--i) ny[i-1]=1ll*ny[i]*i%P;

    solve(0,m);

    // f[i] stores f_i(r)/i!, so multiply back by i! before weighting with a_i.
    for (int i=0;i<=m;++i) s=X(s+1ll*f[i]*jc[i]%P*a[i]%P);
    cout<<s<<endl;
    return 0;
}

 

posted @ 2020-02-09 22:24  xjqxjq  阅读(135)  评论(0)  编辑  收藏  举报