玩游戏

Description

对于 \(1\leq n,m,k\leq 10^5\),给定序列 \(\{a_i\}_{i=1}^n\) 与 \(\{b_j\}_{j=1}^m\),对所有 \(t\in[1,k]\) 求

\[\frac{1}{nm}\sum_{i=1}^n \sum_{j=1}^m (a_i+b_j)^t \]

Solution

记所求和式的分子为 \(c_t\),用二项式定理展开 \((a_i+b_j)^t\),有

\[\begin{align} c_t&=\sum_{i=1}^n \sum_{j=1}^m \sum_{r=0}^t \binom{t}{r} a_i^r b_j^{t-r}\\ &=\sum_{r=0}^t \binom{t}{r} \Big( \sum_{i=1}^n a_i^r\Big) \Big( \sum_{j=1}^m b_j^{t-r}\Big) \end{align} \]

设 \(A(x)=\sum_{w\geq 0} \Big(\sum_{i=1}^n a_i^w\Big) x^w\),有

\[\begin{align} A(x)&=\sum_{i=1}^n \sum_{w\geq 0} a_i^w x^w \\ &=\sum_{i=1}^n \frac{1}{1-a_ix} \\ &=n-\sum_{i=1}^n \frac{-a_ix}{1-a_ix} \\ &=n-x\sum_{i=1}^n \frac{-a_i}{1-a_ix} \\ &=n-x\sum_{i=1}^n \big(\ln (1-a_ix)\big)' \\ &=n-x\Big(\ln\prod_{i=1}^n (1-a_ix)\Big)' \end{align} \]

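于是只需用分治乘法求出 \(\prod_{i=1}^n (1-a_ix)\),再求一次 \(\ln\) 即可得到 \(A(x)\),然后把第 \(w\) 项除以 \(w!\),转成 EGF 形式。\(B\) 与 \(A\) 同理。最后 \(C\)(\(c_t\) 的生成函数)就是 \(A\) 与 \(B\) 的二项卷积:把两者的 EGF 相乘,第 \(t\) 项再乘以 \(t!\) 即得 \(c_t\)。不要忘了除以 \(nm\)。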

#include<stdio.h>
#define rint register int

typedef long long ll;

/*
 * Read one decimal integer from stdin.
 *
 * Every non-digit before the number is skipped; a '-' anywhere in that
 * skipped prefix makes the result negative. The first non-digit after the
 * number is consumed.
 *
 * Returns 0 if end of input is reached before any digit. getchar() returns an
 * int so that EOF can be told apart from every real character; storing it in
 * a char (as before) made the skip loop spin forever at end of input.
 */
inline int read(){
    int x=0,flag=1;
    int c=getchar();
    while(c!=EOF&&(c<'0'||c>'9')){
        if(c=='-') flag=0;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
    return flag? x:-x;
}

// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int Mod = 998244353;
// Root used by NTT for the forward transform (its inverse Gi is used for the inverse).
const int G = 3;
// Capacity of every global/static coefficient buffer (2^21 plus a little slack).
const int N = (1 << 21) + 7;

/*
 * Binary exponentiation: returns x^y mod Mod.
 * With the default exponent Mod-2 this yields the modular inverse of x
 * (Fermat's little theorem, Mod being prime).
 */
ll qpow(ll x,int y=Mod-2){
    ll result=1;
    for(;y!=0;y>>=1){
        if(y&1) result=result*x%Mod;
        x=x*x%Mod;
    }
    return result;
}

// Inverse of G, the root used by the inverse transform.
const int Gi=qpow(G);

int rk[N];
inline void swap(ll &x,ll &y){x^=y,y^=x,x^=y;}
void NTT(bool op,int n,ll *F){
    for(rint i=0;i<n;i++)
        if(i<rk[i]) swap(F[i],F[rk[i]]);
    for(rint p=2;p<=n;p<<=1){
        rint len=p>>1;
        ll w=qpow(op? G:Gi,(Mod-1)/p);
        for(rint k=0;k<n;k+=p){
            ll now=1;
            for(rint l=k;l<k+len;l++){
                ll t=F[l+len]*now%Mod;
                F[l+len]=(F[l]-t+Mod)%Mod;
                F[l]=(F[l]+t)%Mod;
                now=now*w%Mod;
            }
        }
    }
}

// Copy the first n coefficients of b into a (front to back).
inline void Cop(int n,ll *a,ll *b){
    while(n-->0) *a++=*b++;
}
// Zero the first n coefficients of F.
inline void Clear(int n,ll *F){
    while(n-->0) *F++=0;
}
// Fill rk[0..n) with the bit-reversal permutation for power-of-two length n.
inline void Rk(int n){
    const int top=n>>1;
    for(int i=0;i<n;++i){
        rk[i]=(rk[i>>1]>>1)|((i&1)? top:0);
    }
}

/*
 * Polynomial product: X[0..n+m] = a[0..n] * b[0..m] (mod Mod).
 * n and m are degrees (a supplies n+1 coefficients, b supplies m+1).
 * The inputs are copied into scratch buffers before X is written, so X may
 * alias a or b. Overwrites rk[]; the scratch buffers are zeroed on return.
 */
void Mul(ll *X,int n,int m,ll *a,ll *b){
    static ll lhs[N],rhs[N];
    Cop(n+1,lhs,a);
    Cop(m+1,rhs,b);
    const int deg=n+m;
    int len=1;
    while(len<=deg) len<<=1;
    Rk(len);
    NTT(1,len,lhs);
    NTT(1,len,rhs);
    for(int i=0;i<len;++i) lhs[i]=lhs[i]*rhs[i]%Mod;
    NTT(0,len,lhs);
    const ll scale=qpow(len);
    for(int i=0;i<=deg;++i) X[i]=lhs[i]*scale%Mod;
    Clear(len,lhs);
    Clear(len,rhs);
}

void Inv(int n,ll *a,ll *b){
    static ll x[N];
    if(n==1){b[0]=qpow(a[0]);return;}
    Inv((n+1)>>1,a,b); int m=n;
    for(n=1;n<(m<<1);n<<=1); Rk(n);
    Clear(n,x),Cop(m,x,a);
    NTT(1,n,x),NTT(1,n,b);
    for(rint i=0;i<n;i++)
        b[i]=b[i]*(2-x[i]*b[i]%Mod+Mod)%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<m;i++) b[i]=b[i]*inv%Mod;
    for(rint i=m;i<n;i++) b[i]=0;
}

/*
 * Power-series logarithm: b[0..n) = integral of a'/a, truncated mod x^n,
 * with b[0] = 0. This equals ln(a) when a[0] == 1 (assumed by callers).
 * Overwrites rk[].
 */
void ln(int n,ll *a,ll *b){
    static ll x[N];
    // Inv runs NTTs over x of length up to the next power of two >= 2n
    // (< 4n) and needs x to be zero there on entry. Clearing only 2n
    // entries could leave stale coefficients from an earlier, larger call.
    int len=1;
    while(len<(n<<1)) len<<=1;
    Clear(len,x);
    Inv(n,a,x);                          // x = 1/a mod x^n
    // b <- a' (n-1 coefficients), top slot zeroed.
    for(int i=0;i<n-1;i++) b[i]=a[i+1]*(i+1)%Mod;
    b[n-1]=0;
    Mul(x,n-1,n-1,b,x);                  // x <- a' * (1/a)
    // b <- integral of x, truncated to n terms.
    for(int i=1;i<n;i++) b[i]=x[i-1]*qpow(i)%Mod;
    b[0]=0;
}

void Solve(int l,int r,ll *a,ll *b){
    int len=r-l+1;
    if(l==r){b[0]=1,b[1]=(Mod-a[l])%Mod;return;}
    int mid=(l+r)>>1,n=1; for(;n<=len;n<<=1);
    ll Lf[N],Rf[N];
    Clear(n,Lf),Clear(n,Rf);
    Solve(l,mid,a,Lf),Solve(mid+1,r,a,Rf); Rk(n);
    NTT(1,n,Lf),NTT(1,n,Rf);
    for(rint i=0;i<n;i++) b[i]=Lf[i]*Rf[i]%Mod;
    NTT(0,n,b); ll inv=qpow(n);
    for(rint i=0;i<=len;i++) b[i]=b[i]*inv%Mod;
    for(rint i=len+1;i<n;i++) b[i]=0;
}

// Input sizes; main later reuses them as scratch for the final NTT length.
int n,m;
// a/b: input values, later overwritten with ln output; A/B: products, then
// power sums / EGFs; fac/inv: factorials and inverse factorials.
ll a[N],b[N],A[N],B[N],fac[N],inv[N];

// Larger of two ints (the second one when they are equal).
inline int max(int x,int y){
    if(y<x) return x;
    return y;
}

/*
 * Reads n, m, a[1..n], b[1..m], t and prints, for every s in 1..t,
 *   (1/(n*m)) * sum_i sum_j (a_i + b_j)^s   (mod Mod).
 * Pipeline: P(x) = prod(1 - a_i x) by divide and conquer, then ln P; the
 * power sums S_w = sum_i a_i^w follow from A(x) = n - x * (ln P)'(x).
 * Dividing S_w by w! gives an EGF; the product of the two EGFs, with the
 * s-th coefficient multiplied by s!, is the double sum.
 */
int main(){
    n=read(),m=read();
    for(rint i=1;i<=n;i++) a[i]=read();
    for(rint i=1;i<=m;i++) b[i]=read();
    // k = number of power sums computed. k >= n, m also keeps the leftover
    // coefficients of A/B above index k (from the degree-n/m products) zero.
    int t=read(); int k=max(max(n,m),t); 
    Solve(1,n,a,A),Solve(1,m,b,B);          // A = prod(1 - a_i x), B likewise
    // a/b are reused as outputs: they receive ln A / ln B (k+2 terms).
    ln(k+2,A,a),ln(k+2,B,b); 
    // Differentiate in place: a = (ln A)', b = (ln B)'.
    for(rint i=0;i<=k;i++)
        a[i]=a[i+1]*(i+1)%Mod,b[i]=b[i+1]*(i+1)%Mod;
    a[k+1]=b[k+1]=0;
    // Power sums: A[0] = n, A[i] = -[x^(i-1)] (ln A)' for i >= 1; same for B.
    A[0]=n; B[0]=m;
    for(rint i=1;i<=k;i++)
        A[i]=(Mod-a[i-1])%Mod,B[i]=(Mod-b[i-1])%Mod;
    // Factorials and inverse factorials up to 2k.
    fac[0]=1; int rg=k<<1;
    for(int i=1;i<=rg;i++) fac[i]=fac[i-1]*i%Mod;
    inv[rg]=qpow(fac[rg]);
    for(rint i=rg-1;~i;i--) inv[i]=inv[i+1]*(i+1)%Mod;
    // Turn the power sums into EGFs: divide the i-th term by i!.
    for(rint i=0;i<=k;i++)
        A[i]=A[i]*inv[i]%Mod,B[i]=B[i]*inv[i]%Mod;
    ll ret=1ll*n*m%Mod;                       // n*m, saved before n and m are reused
    // n is reused as the NTT length (smallest power of two > 2k).
    for(m=rg,n=1;n<=m;n<<=1); Rk(n);
    NTT(1,n,A),NTT(1,n,B);
    for(rint i=0;i<n;i++) A[i]=A[i]*B[i]%Mod;
    // inv_ folds the 1/len of the inverse NTT together with 1/(n*m).
    NTT(0,n,A); ll inv_=qpow(n)*qpow(ret)%Mod;
    // Answer for s: s! * [x^s](EGF_A * EGF_B) / (n*m).
    for(rint i=1;i<=t;i++) printf("%lld\n",A[i]*inv_%Mod*fac[i]%Mod);
}
posted @ 2021-07-18 14:33  Kreap  阅读(46)  评论(1)  编辑  收藏  举报