题解 洛谷 P4705 【玩游戏】

先化简答案的式子:

\[\large\begin{aligned} ans_k&=\frac{1}{nm}\sum_{i=1}^n\sum_{j=1}^m(a_i+b_j)^k \\ &=\frac{1}{nm}\sum_{i=1}^n\sum_{j=1}^m\sum_{l=0}^k\binom{k}{l}a_i^lb_j^{k-l} \\ &=\frac{k!}{nm}\sum_{l=0}^k\left (\sum_{i=1}^n\frac{a_i^l}{l!}\right )\left (\sum_{j=1}^m\frac{b_j^{k-l}}{(k-l)!} \right ) \end{aligned} \]

发现是卷积的形式,那么只要能快速计算 \(s_k=\sum\limits_{i=1}^n a_i^k\),然后进行卷积计算答案即可。

构造生成函数 \(f(x)=\sum\limits_{i \geqslant 0} s_ix^i\),进行化简:

\[\large f(x)=\sum_{i \geqslant 0} s_ix^i =\sum_{i \geqslant 0} \sum_{j=0}^{n-1} a_j^ix^i =\sum_{j=0}^{n-1}\frac{1}{1-a_jx} \]

注意到:

\[\large {(\ln (1-a_ix))}'= \frac{-a_i}{1-a_ix} \]

代入得:

\[\large\begin{aligned} f(x)&=\sum_{i=0}^{n-1}\frac{1}{1-a_ix} \\ &=\sum_{i=0}^{n-1}1-x{(\ln (1-a_ix))}' \\ &=n-x{\left (\sum_{i=0}^{n-1}\ln (1-a_ix)\right )}' \\ &=n-x\left (\ln\prod_{i=0}^{n-1} (1-a_ix)\right )' \end{aligned} \]

对于 \(\prod\limits_{i=0}^{n-1} (1-a_ix)\),可以通过分治来求解,即先算出左右两半后乘起来。

最后通过多项式对数函数即可求解答案。

#include<bits/stdc++.h>
#define maxn 800010
#define P 998244353
#define G 3
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
ll n,m,k,all;
ll rev[maxn],a[maxn],b[maxn],f[maxn],g[maxn],fac[maxn],ifac[maxn];
ll qp(ll x,ll y)
{
    ll v=1;
    while(y)
    {
        if(y&1) v=v*x%P;
        x=x*x%P,y>>=1;
    }
    return v;
}
int calc(int n)
{
    int lim=1;
    while(lim<=n) lim<<=1;
    for(int i=0;i<lim;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)?lim>>1:0);
    return lim;
}
void NTT(ll *a,int lim,int type)
{
    for(int i=0;i<lim;++i)
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    for(int len=1;len<lim;len<<=1)
    {
        ll wn=qp(G,(P-1)/(len<<1));
        for(int i=0;i<lim;i+=len<<1)
        {
            ll w=1;
            for(int j=i;j<i+len;++j,w=w*wn%P)
            {
                ll x=a[j],y=w*a[j+len]%P;
                a[j]=(x+y)%P,a[j+len]=(x-y+P)%P;
            }
        }
    }
    if(type==1) return;
    ll inv=qp(lim,P-2);
    for(int i=0;i<lim;++i) a[i]=a[i]*inv%P;
    reverse(a+1,a+lim);
}
void Inv(int deg,ll *a,ll *b)
{
    static ll t[maxn];
    if(deg==1)
    {
        b[0]=qp(a[0],P-2);
        return;
    }
    Inv((deg+1)>>1,a,b);
    int lim=calc(deg<<1);
    for(int i=0;i<deg;++i) t[i]=a[i];
    for(int i=deg;i<lim;++i) t[i]=b[i]=0;
    NTT(t,lim,1),NTT(b,lim,1);
    for(int i=0;i<lim;++i)
        b[i]=b[i]*(2-t[i]*b[i]%P+P)%P;
    NTT(b,lim,-1);
    for(int i=deg;i<lim;++i) b[i]=0;
}
void Ln(int deg,ll *a,ll *b)
{
    static ll inva[maxn],dera[maxn];
    Inv(deg,a,inva);
    for(int i=0;i<deg-1;++i) dera[i]=a[i+1]*(i+1)%P;
    dera[deg-1]=0;
    int lim=calc(deg<<1);
    for(int i=deg;i<lim;++i) dera[i]=inva[i]=0;
    NTT(dera,lim,1),NTT(inva,lim,1);
    for(int i=0;i<lim;++i) b[i]=dera[i]*inva[i]%P;
    NTT(b,lim,-1);
    for(int i=deg-1;i>=1;--i) b[i]=b[i-1]*qp(i,P-2)%P;
    b[0]=0;
    for(int i=deg;i<lim;++i) b[i]=0;
}
void solve(int l,int r,ll *a,ll *b)
{
    if(l==r)
    {
        b[0]=1,b[1]=P-a[l];
        return;
    }
    int mid=(l+r)>>1;
    ll a1[maxn],a2[maxn];
    solve(l,mid,a,a1),solve(mid+1,r,a,a2);
    int lim=calc(r-l+1);
    for(int i=mid-l+2;i<lim;++i) a1[i]=0;
    for(int i=r-mid+1;i<lim;++i) a2[i]=0;
    NTT(a1,lim,1),NTT(a2,lim,1);
    for(int i=0;i<lim;++i) b[i]=a1[i]*a2[i]%P;
    NTT(b,lim,-1);
    for(int i=r-l+2;i<lim;++i) b[i]=0;
}
void get(int n,ll *a,ll *b)
{
    static ll t[maxn];
    solve(1,calc(all),a,t),Ln(all+1,t,b);
    for(int i=0;i<=all;++i) b[i]=b[i+1]*(i+1)%P;
    b[all+1]=0;
    for(int i=all;i;--i) b[i]=P-b[i-1];
    b[0]=n;
    for(int i=0;i<=all;++i) b[i]=b[i]*ifac[i]%P;
}
void init()
{
    fac[0]=ifac[0]=1;
    for(int i=1;i<=all;++i) fac[i]=fac[i-1]*i%P;
    ifac[all]=qp(fac[all],P-2);
    for(int i=all-1;i;--i) ifac[i]=ifac[i+1]*(i+1)%P;
}
int main()
{
    read(n),read(m);
    for(int i=1;i<=n;++i) read(a[i]);
    for(int i=1;i<=m;++i) read(b[i]);
    read(k),all=max(k,max(n,m)),init();
    get(n,a,f),get(m,b,g);
    int lim=calc(all<<1);
    NTT(f,lim,1),NTT(g,lim,1);
    for(int i=0;i<lim;++i) f[i]=f[i]*g[i]%P;
    NTT(f,lim,-1);
    ll inv=qp(n*m%P,P-2);
    for(int i=1;i<=k;++i) printf("%lld\n",fac[i]*inv%P*f[i]%P);
    return 0;
}
posted @ 2020-08-05 20:04  lhm_liu  阅读(181)  评论(1编辑  收藏  举报