Loading

【题解】CF850F Rainbow Balls

整体方向很常规,但是最后的处理比较仙,记一下。

思路

期望 dp.

首先意识到最终会变成同一种颜色,并且不同颜色的期望步数不同。

考虑到 \(n \leq 2.5 \times 10^3\),考虑钦定最终的颜色。

假设钦定的颜色为 \(c\),令 \(f[i]\) 为当前已有 \(i\) 个颜色为 \(c\) 的球时,令所有球颜色变成 \(c\) 需要的最小步数。

转移简单分讨一下选同色球和选异色球的情况,所以直觉上的转移是这样的。

\(f[i] = (f[i - 1] + f[i + 1]) p + (f[i] + 1)\),其中 \(p\) 是选到异色球的概率。

令球的总数为 \(m\),可以推出 \(p = \frac{i (m - i)}{m (m - 1)}\),所以原方程等价于 \(f[i] = (f[i - 1] + f[i + 1]) \frac{i (m - i)}{m (m - 1)} + (f[i] + 1)\).

上面这种转移是一种误区,原因是期望 dp 转移对边界的依赖。

考虑这个 dp 的边界情况,\(f[m] = 0\),但是 \(f[0]\) 无法定义:因为没有颜色为 \(c\) 的球时无法通过操作获得颜色为 \(c\) 的球。

一般情况下的 dp 都会根据题目需要把这个值设成 \(0\) 或者 \(\pm \infty\),但是期望 dp 的转移一般是通过化简相邻的 \(f\) 之间的关系得到的,因此强行给 \(f[0]\) 赋值会导致后面的 dp 都是错的。

所以我们考虑绕过 \(f[0]\) 转移。

\((f[i - 1] + f[i + 1]) p\) 这一项符合期望的定义,不用修正。需要考虑的是 \(f[i] + 1\) 这一项:在球的个数不能为 \(0\) 的前提下,考虑使 \(i\) 变成 \(m\) 的期望步数。

这个问题放在数轴上是一个经典的问题。设 \(g[i]\) 表示从 \(i\) 到达 \(m\) 的概率。因为当前点等概率向左右走,所以 \(g[i] = (g[i - 1] + g[i + 1]) p + (1 - 2p) g[i]\),这里 \(p\) 的定义和上面相同。

因为 \(p[0] = 0, p[s] = 1\),所以用 \(0\) 转移没有上面的那种影响。

上式化简得 \(g[i] - g[i - 1] = g[i + 1] - g[i]\),所以 \(g[i] = \frac{i}{m}\).

考虑代回 \(f\) 的转移方程,用 \(f[i]\) 转移的概率是 \(\frac{i}{m}\),代价是 \(1\),所以有:\(f[i] = (f[i - 1] + f[i + 1]) p + f[i] \frac{i}{m}\).

化简得到 \(f[i] - f[i + 1] = f[i - 1] - f[i] + \frac{m - 1}{m - i}\).

边界情况特殊考虑。因为 \(f[0]\) 不存在,所以 \(f[1] = f[2] p + (1 - 2p) f[1] + \frac{1}{m}\),也就是 \(f[2] = 2 f[1] - 1\).

所以 \(f[1] = f[1] - f[m] = \sum\limits_{i = 2}^m f[i - 1] - f[i] = (m - 1) (f[1] - f[2]) + \sum\limits_{i = 2}^{m - 1} \frac{m - 1}{m - i} (m - i)\)

又因为 \(f[2] = 2 f[1] - 1\),所以 \(f[1] = \frac{(m - 1)^2}{m}\).

处理出 \(f[1], f[2]\) 就可以按照递推式求 \(f\) 了。答案是 \(\sum\limits_{i = 1}^n f[a_i]\).

递推式整理出来是 \(f[i] = (2 f[i - 1] - f[i - 2]) - \frac{m - 1}{m - i}\).

时间复杂度 \(O(n \log n)\),也可以线性求逆元做到 \(O(n)\).

另外有神仙用停时和鞅怒草此题,不太懂,但感觉好强。

代码

#include <cstdio>
using namespace std;

const int maxn = 1e5 + 5;
const int mod = 1e9 + 7;

int n, m;
int a[maxn], f[maxn];

inline int max(const int &a, const int &b) { return (a >= b ? a : b); }
inline void add(int &x, int y) { if ((x += y) >= mod) x -= mod; }

int qpow(int base, int power)
{
    int res = 1;
    while (power)
    {
        if (power & 1) res = 1ll * res * base % mod;
        base = 1ll * base * base % mod;
        power >>= 1;
    }
    return res;
}

int main()
{
    scanf("%d", &n);
    int lim = 0;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]), m += a[i], lim = max(lim, a[i]);
    f[1] = 1ll * (m - 1) * (m - 1) % mod * qpow(m, mod - 2) % mod;
    f[2] = (2 * f[1] % mod - 1 + mod) % mod;
    for (int i = 2; i <= lim; i++) f[i + 1] = ((2 * f[i] % mod - f[i - 1] + mod) % mod - 1ll * (m - 1) * qpow(m - i, mod - 2) % mod + mod) % mod;
    int ans = 0;
    for (int i = 1; i <= n; i++) ans = (ans + f[a[i]]) % mod;
    printf("%d\n", ans);
    return 0;
}
posted @ 2023-02-09 18:33  kymru  阅读(17)  评论(0编辑  收藏  举报