2022.10.3-D 宝石

题意:

初始有 \(n\) 种宝石,每种宝石有 \(1\) 颗。

现在你要进行 \(m\) 次操作,每次等概率选择一个宝石,将其复制一遍。

问最后数量最多的前 \(k\) 种宝石的期望数量。

\(n, m, k\le 500\)


思路

考虑构造一个单调不上升的序列 \(a_1, a_2, ..., a_n\)\(a_i\) 表示某一种宝石被操作的次数。

对这个序列计算它的贡献,则有:

\[(k+\sum_{i=1}^k a_i)\frac{\prod_{i=1}^n a_i!}{n^{\overline{m}}}\times \frac{m!}{\prod_{i=1}^n a_i!} \]

左边的分式计算的是这个序列单次的概率,右边的分式计算的是其对应的操作序列的个数。

\(n = 3, m = 3\),对于一种操作序列 \((2, 2, 3)\),其出现的概率为 \(\frac{1}{3}\times\frac{2}{4}\times\frac{1}{5}\)(对应左分式);而与它同类型的操作序列还有 \((2, 3, 2)\) \((3, 2, 2)\)(对应右分式)。

然后我们惊喜地发现,左右分式可以消掉一些项,就直接化成了:

\[(k+\sum_{i=1}^k a_i)\frac{m!}{n^{\overline{m}}} \]

好了,学会计算一种序列单次贡献,那么我们来计算一种序列对应多少种情况。

还是上面的例子,它对应的 \(a\) 序列为 \(2,1,0\),与它相同的操作序列还有 \((1,1,2)\) \((1,1,3)\) \((2,2,1)\) \((3,3,1)\) \((3,3,2)\) 等多种情况。

我们再构造一个序列 \(b_1, b_2, ..., b_m\)\(b_i\) 表示有多少种宝石被恰好操作了 \(i\) 次(可以理解为有多少个 \(a_k=i\))。

那么对于一种 \(a\) 序列,就有一下计算方式:

\[(k+\sum_{i=1}^k a_i)\frac{m!}{n^{\overline{m}}}\times\frac{n!}{\prod_{i=1}^m b_i!} \]

显然,我们可以将 \(m!,~n!,~n^{\overline{m}}\) 提取出来。

对于 \(\prod_{i=1}^m b_i!\),我们考虑用 DP 来计算:

我们设 \(f_{i,j,o}\) 表示我们考虑到同种宝石的操作次数 \(\ge i\),其中有 \(j\) 种宝石已经被选,共进行 \(o\) 次操作的情况。

我们枚举有 \(p\) 种宝石都是操作次数为 \(i\) 的,那么就有转移:

\[f_{i,j+p,o+i*p}\leftarrow \frac{f_{i,j,o}}{p!} \]

而对于 \(j<k,~j+p\ge k\) 的情况,这时候恰好越过了前 \(k\) 大的界限,这时候我们就计算期望,也就是有转移:

\[f_{i,j+p,o+i*p}\leftarrow \frac{f_{i,j,o}}{p!}\times((k-j)\times i+o+k) \]

最后的答案就是 \(f_{0,n,m}\times n!\times m!\times\frac{(n-1)!}{(n+m-1)!}\)


代码

#include<bits/stdc++.h>
#define LL long long
#define FOR(i, x, y) for(int i = (x); i <= (y); i++)
#define ROF(i, x, y) for(int i = (x); i >= (y); i--)
#define PFOR(i, x) for(int i = he[x]; i; i = r[i].nxt)
inline int rd()
{
    int sign = 1, re = 0; char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') sign = -1; c = getchar();}
    while('0' <= c && c <= '9'){re = re * 10 + (c - '0'); c = getchar();}
    return sign * re;
}
namespace MOD
{
    const int mod = 998244353;
    inline void add(int &a, int b) {a = a + b >= mod ? a + b - mod : a + b;}
    inline int mul(int a, int b) {return 1ll * a * b % mod;}
    inline int sub(int a, int b) {return a - b < 0 ? a - b + mod : a - b;}
    inline int fast_pow(int a, int b = mod - 2)
    {
        int re = 1;
        while(b)
        {
            if(b & 1) re = mul(re, a);
            a = mul(a, a);
            b >>= 1;
        }
        return re;
    }
} using namespace MOD;
int n, m, kth;
int fac[1005], ifac[1005];
inline void Init()
{
    int N = 1000;
    fac[0] = fac[1] = ifac[0] = ifac[1] = 1;
    FOR(i, 2, N) fac[i] = mul(fac[i - 1], i);
    ifac[N] = fast_pow(fac[N]);
    ROF(i, N - 1, 2) ifac[i] = mul(ifac[i + 1], i + 1);
}
int f[505][505][505];
signed main()
{
    Init();
    n = rd(), m = rd(), kth = rd();
    f[m + 1][0][0] = 1;
    ROF(i, m, 0) FOR(j, 0, n) FOR(k, 0, m) if(f[i + 1][j][k])
        for(int p = 0; j + p <= n && k + i * p <= m; p++)
        {
            if(j < kth && j + p >= kth)
                add(f[i][j + p][k + i * p], mul(mul(f[i + 1][j][k], ifac[p]), k + (kth - j) * i + kth));
            else
                add(f[i][j + p][k + i * p], mul(f[i + 1][j][k], ifac[p]));
        }
    int ans = mul(f[0][n][m], mul(fac[m], fac[n]));
    ans = mul(ans, mul(ifac[n + m - 1], fac[n - 1]));
    printf("%d", ans);
    return 0;
}
posted @ 2022-10-15 09:10  zuytong  阅读(18)  评论(0编辑  收藏  举报