[ABC215G]Colorful Candies 2 题解

期望

Statement

G - Colorful Candies 2 (atcoder.jp)

给定 \(n\) 个糖果,第 \(i\) 个糖果颜色为 \(c_i\)

对于每个 \(k=1∼n\),求随机选出 \(k\) 个糖果,\(\binom nk\) 种情况中糖果颜色数的期望。答案模 \(998244353\)

\(n\le 5\times 10^4\)

Solution

知道这类问题的一般套路都是先利用期望的线性性转化问题求解

但是不会用期望的线性性,于是最开始设了一个 \(dp[i][j][k]\) 表示前 \(i\) 种颜色中选择 \(j\) 个,选出 \(k\) 不同颜色的方案数,然后前缀和优化之后得到一个 \(O(n^3)\) 后发现状态实在减不动了,GG

回到正题,注意到本题中对答案有影响的是不同颜色的数量和每种颜色数量,并不关注每种颜色具体是什么

枚举 \(k\) ,题目要求计算 \(E(\sum x_i)\)\(x_i=1/0\) 表示选择 \(k\) 个数,颜色 \(i\) 是/否被选中

\(E(\sum x_i)=\sum E(x_i)\) ,考虑求出选择 \(k\) 个数中有 \(i\) 的期望

简单容斥一下,用总方案数-选不中方案数,设颜色 \(i\)\(cnt_i\)

\[E(x_i)=\dfrac{\binom nk-\binom {n-cnt_i}k}{\binom nk} \]

暴力计算是 \(O(n^2)\) 的,考虑优化

注意到其实 \(cnt\) 相同的可以和在一起算,又 \(\sum cnt_i=n\) ,所以不同的 \(cnt\) 只有 \(\sqrt n\) 种,复杂度来到 \(O(n\sqrt n)\)

Code

#include<bits/stdc++.h>
using namespace std;
const int N = 5e4+5;
const int mod = 998244353;

char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
int read(){
    int s=0,w=1; char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();}
    while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
    return s*w;
}
int ksm(int a,int b){
    int res=1;
    while(b){
        if(b&1)res=1ll*res*a%mod;
        a=1ll*a*a%mod,b>>=1;
    }
    return res;
}
void inc(int &a,int b){a=a>=mod-b?a-mod+b:a+b;}

vector<int>vec;
int jc[N],invj[N],a[N],b[N];
int tong[N],cnt[N];
bool vis[N];
int n,m;

int C(int n,int m){
    if(n<m||m<0)return 0;
    return 1ll*jc[n]*invj[m]%mod*invj[n-m]%mod;
}

signed main(){
    n=read();
    for(int i=jc[0]=1;i<=n;++i)
        jc[i]=1ll*jc[i-1]*i%mod,a[i]=b[i]=read();
    invj[n]=ksm(jc[n],mod-2);
    for(int i=n-1;~i;--i)invj[i]=1ll*invj[i+1]*(i+1)%mod;
    sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<=n;++i)
        tong[lower_bound(b+1,b+1+m,a[i])-b]++;
    for(int i=1;i<=m;++i){
        cnt[tong[i]]++;
        if(vis[tong[i]]==0)
            vec.push_back(tong[i]),vis[tong[i]]=1;
    }
    for(int k=1;k<=n;++k){
        int ans=0,all=C(n,k),invl=ksm(all,mod-2);
        for(auto v:vec)
            inc(ans,1ll*cnt[v]*((all-C(n-v,k)+mod)%mod)%mod*invl%mod);
        printf("%d\n",ans);
    }
    return 0;
}
posted @ 2022-04-19 19:50  _Famiglistimo  阅读(55)  评论(0编辑  收藏  举报