题解 P5326【[ZJOI2019] 开关】/ SS241023B【判断题】
已经沦落为可以随便搬进模拟赛的模板题了。。。
题目描述
当前有 \(n\) 道判断题,初始全选的错。初始给出每道题的正确答案,设 \(0\) 表示错,\(1\) 表示对。每道题有一个参数 \(p_i\),每轮会以 \(\frac {p_i} {\sum_{j = 1} ^ {n} p_j}\) 的概率选择第 \(i\) 道题并修改(flip)这道题的答案。问期望经过多少轮以后,\(n\) 道题与正确答案完全一致,答案对 \(998244353\) 取模。记 \(a_i\) 表示第 \(i\) 道题的正确答案。
保证对于所有数据满足:\(1 \leq n \leq 100\),\(\sum p_i \leq 5 \times 10 ^ 5\),\(p_i \geq 1\)。
概率生成函数基础理论
使用 PGF(概率生成函数)刻画原问题。PGF 是一个无穷项的 OGF(注意是 OGF),其 \([z^i]\) 项系数表示 \(i\) 步达成某种目标的概率,表示为 \(F(z)=\sum_{i\geq 0}f_ix^i\)。若已知达成某种目标(事件)的 PGF 为 \(F(z)=\sum_{i\geq 0}f_ix^i\),则达成这种目标的期望步数为 \(\sum_{i\geq 0}f_i\cdot i=F'(1)\)。
令
- 从初态到终态的 PGF 为 \(F(z)\);
- 从终态到终态的 PGF 为 \(G(z)\);
- 从初态首次到终态的 PGF 为 \(H(z)\)(此前 \(F, G\) 均没有强调首次,表示可以经过很多次)。
则可以得知:\(H(z)\cdot G(z)=F(z)\) 即 \(H(z)=\dfrac{F(z)}{G(z)}\)。我们要求的从初态首次到达终态的期望步数即为
以上都是理想情况,现实中 \(F(z), G(z)\) 不一定那么好算,甚至 \(H(1)\) 的计算也可能需要动手脚,但是分式求导几乎总会出现。很多情况下(几乎任何情况下都)会出现 \(F(1), G(1)\) 不收敛的情况,但因为答案是收敛的,此时一般可以修改 \(F(z), G(z)\),使它们额外乘上 \((1-z)\) 的因式,以去除某些值为 \(0\) 的分母。
拉普拉斯算子
有时需要将 EGF 转换为 OGF。记 \(\hat F(z)\) 为 \(F(z)\) 的 EGF 形式(即 \(F(z)=\sum_{i\geq 0}f_ix^i, \hat F(z)=\sum_{i\geq 0}f_ix^i/i!\))。前人从不知道哪个地方偷了个称作“拉普拉斯算子”(Laplace Operator)的符号 \(\mathscr L\)(\mathscr L
)拿来用,并赋予其新定义 \(\mathscr L\hat F(z)=F(z)\),实现 EGF 向 OGF 的转化。可以感受到
- \(\mathscr L\) 是线性算子,\(\mathscr L(a\hat F(z)+b\hat G(z))=a\mathscr L\hat F(z)+b\mathscr L\hat G(z)\)。
- \(\mathscr Lae^{bz}=\dfrac{a}{1-bz}\),因为它们都指名同一个序列 \(\{a, ab, ab^2, ab^3, ab^4,\cdots\}\),前者 \(ae^{bz}\) 是 EGF 形式,后者 \(\dfrac{a}{1-bz}\) 是 OGF 形式。
solution
沿用上文记号,但是 OGF 不好刻画某个步数下的操作序列,EGF 才好刻画(我们需要 EGF 作为二项卷积的载体)。
限制即为第 \(i\) 道题目要被修改的次数 \(\equiv a_i\pmod 2\)。记 \(s=\sum_{i=1}^np_i\)。不演了,直接写 \(\hat F(z), \hat G(z)\),希望读者自己说明其意义:
由于 \(O(ns)\) 的复杂度可接受,我们可以暴力计算并展开 \(\hat F(z), \hat G(z)\) 写为
的形式,这样的好处是我们可以将这个 EGF 转为 OGF。
这样以后我们就可以计算 \(H'(1)\) 了,需要计算 \(F(1), G(1), F'(1), G'(1)\)。但正如你所见当 \(z=1,w=s\) 时 \(\frac{a_w}{1-wz/s}\) 的分母为 \(0\),十分滑稽。这里的修正方法是使 \(F(z), G(z)\) 都乘 \((1-z)\) 使刚才说的那一项变为 \(a_s\) 这个常数,那么就可以计算了。这里需要对 \(F(z), G(z)\) 求导,其实也是一样用求导除法公式,参见上文 \(H'(z)\) 部分,这里需要你用草稿纸算好式子再写到程序里。
复杂度 \(O(ns)\)。
code
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
template <unsigned umod>
struct modint {// {{{
static constexpr int mod = umod;
unsigned v;
modint() = default;
template <class T, enable_if_t<is_integral<T>::value, int> = 0>
modint(const T& _v): v((unsigned)(_v % mod + (_v < 0 ? mod : 0))) {}
modint& operator+=(const modint& rhs) { v += rhs.v; if (v >= umod) v -= umod; return *this; }
modint& operator-=(const modint& rhs) { v -= rhs.v; if (v >= umod) v += umod; return *this; }
modint& operator*=(const modint& rhs) { v = (unsigned)(1ull * v * rhs.v % umod); return *this; }
modint& operator/=(modint rhs) {
assert(rhs.v);
static constexpr int lim = 1 << 20;
while (rhs.v >= lim) *this *= mod - mod / rhs.v, rhs.v = umod % rhs.v;
static vector<modint> inv{0, 1};
while (rhs.v >= inv.size()) {
auto m = inv.size();
inv.resize(m << 1);
for (auto i = m; i < m << 1; i++) inv[i] = (mod - mod / i) * inv[mod % i];
}
return *this *= inv[rhs.v];
}
friend modint operator+(modint lhs, const modint& rhs) { return lhs += rhs; }
friend modint operator-(modint lhs, const modint& rhs) { return lhs -= rhs; }
friend modint operator*(modint lhs, const modint& rhs) { return lhs *= rhs; }
friend modint operator/(modint lhs, const modint& rhs) { return lhs /= rhs; }
explicit operator bool() const { return v != 0; }
friend int raw(const modint& self) { return self.v; }
friend ostream& operator<<(ostream& os, const modint& self) { return os << raw(self); }
friend modint qpow(modint a, LL b) {
modint r = 1;
for (; b; b >>= 1, a *= a) if (b & 1) r *= a;
return r;
}
};// }}}
using mint = modint<998244353>;
int n, s[110], p[110], ph;
mint* alloc(int siz) {
static mint _[10000010], *_ptr = _;
return (_ptr += siz << 1 | 1) - siz;
}
mint *a = alloc(5e4), *b = alloc(5e4), *tmp = alloc(5e4);
int main() {
#ifndef LOCAL
cin.tie(nullptr)->sync_with_stdio(false);
#endif
cin >> n;
for (int i = 1; i <= n; i++) cin >> s[i];
for (int i = 1; i <= n; i++) cin >> p[i], ph += p[i];
a[0] = 1, b[0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = -ph; j <= ph; j++) tmp[j] = exchange(a[j], 0);
// a: exp(p[i] / ph)
for (int j = ph; j - p[i] >= -ph; j--) a[j] += tmp[j - p[i]];
// a: (-1)^s[i] * exp(-p[i] / ph)
for (int j = -ph; j + p[i] <= ph; j++) a[j] += tmp[j + p[i]] * (s[i] ? -1 : 1);
for (int j = -ph; j <= ph; j++) tmp[j] = exchange(b[j], 0);
// b: exp(p[i] / ph)
for (int j = ph; j - p[i] >= -ph; j--) b[j] += tmp[j - p[i]];
// b: exp(-p[i] / ph)
for (int j = -ph; j + p[i] <= ph; j++) b[j] += tmp[j + p[i]];
}
mint f1 = a[ph], g1 = b[ph], fd1 = 0, gd1 = 0;
for (int j = -ph; j <= ph - 1; j++) fd1 += a[j] / (mint{j} / ph - 1);
for (int j = -ph; j <= ph - 1; j++) gd1 += b[j] / (mint{j} / ph - 1);
cout << (fd1 * g1 - gd1 * f1) / (g1 * g1) << endl;
return 0;
}
本文来自博客园,作者:caijianhong,转载请注明原文链接:https://www.cnblogs.com/caijianhong/p/18498434/solution-P5326