atcorder 295 E
题目链接:https://atcoder.jp/contests/abc295/tasks/abc295_e
题意:
给定一个长为N的数字序列,序列中每个数字都在[0, M]这个区间中。按顺序做两步操作:
第一步,对于数字序列中每个数字0,独立的并且等概率的从区间[1, M]中选择一个数, 把这个0代替成选出来的数。
第二步,把这个数字序列按照升序排列。
问第K位得到数字的期望是什么,输出答案mod998244353
Simple input
3 5 2
2 0 4
Simple output
3
Solution:
第K位的期望是:\(E_k = \displaystyle\sum^{M}_{i = 1}{i * p_i}\)
这样不是很好求,可以将其转化成\(E_k = \displaystyle\sum^{M}_{i = 1}{i * (b_i - b_{i + 1})}\)
其中\(b_i\)指的是第k位大于等于i的概率。
那么\(E_k = 1 * (b_1 - b_2) + 2 * (b_2 - b_3) + \dots + m - 1 * (b_{m - 1} - b_{m}) + b_m = b_1 + b_2 + \dots + b_m\)
接下来就要求出\(b_i\)
分类讨论:
1.:如果序列中大于等于i的数的个数是大于等于n - k + 1的那么\(b_i\)就是1,因为这种情况在不改变0的情况下已经使得在升序排序后第k位恒大于等于i。
2.:如果序列中大于等于i的个数不足n - k + 1,那么就需要把一些0变成大于等于i的数。
对于这种情况,我们事先统计好其中大于等于i的数的个数,设为cnt,并且统计出来其中0的个数,设为num。
如果num + cnt < n - k + 1, 意思就是说,把所有0都转换成大于等于i的数,仍然无法使得第k位的数大于等于i,因此\(b_i\) = 0。
而如果num + cnt >= n - k + 1, 我们就可以从num个0中选出来需要进行转化的0。
对于每一个0,它转换成大于等于i的数的概率是\(P = \frac{m - i + 1}{m}\)
那么,这种情况\(b_i = \displaystyle\sum^{num}_{i = n - k + 1 - cnt}\left({i\choose num} * P^i * (1 - P)^{num - i}\right)\)
就是加法原理和乘法原理。
Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define int LL
const int mod = 998244353;
const int N = 5010;
int c[N][N];
int qmi(int a, int b, int c) {
LL res = 1;
while(b) {
if(b & 1) res = res * a % c;
a = (LL) a * a % c;
b >>= 1;
}
return res;
}
signed main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);
for(int i = 0; i <= 5005; i ++) {
for(int j = 0; j <= i; j ++) {
if(!j) c[i][j] = 1;
else c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
}
}
int n, m, k; cin >> n >> m >> k;
vector<int> seq(n + 1);
vector<int> stc(2010, 0);
int num = 0;//统计序列中零的个数
for(int i = 1; i <= n; i ++) {
cin >> seq[i];
if(!seq[i]) num ++;
stc[seq[i]] ++;
}
int need = n - k + 1;
int ans = 0;
for(int i = 1; i <= m; i ++) {
int cnt = 0;//统计序列中有多少个数字大于等于i
for(int j = i; j <= m; j ++) cnt += stc[j];
if(cnt >= need) ans = (ans + 1) % mod;
else {
if(cnt + num < need) continue;
else {
int p = (m - i + 1) * qmi(m, mod - 2, mod) % mod;
for(int k = need - cnt; k <= num; k ++) {
ans = (ans + c[num][k] * qmi(p, k, mod) % mod * qmi((1 - p + mod) % mod, num - k, mod) % mod) % mod;
}
}
}
}
cout << ans;
}