序列合并
Problem
有一个序列,初始时为空,你会不断往序列末尾添加一个 \([1, m]\) 的随机整数。
任意时刻
-
若序列末尾两个数相同(记为 \(x\)),且小于 \(t\),则这两个数会合并成 \(x+1\);
-
若序列长度为 \(n\) 且无法合并,则操作结束。
求序列中所有元素和的期望,答案对 \(10^9+7\) 取模。
\(1\leq n,m\leq 10^3\),\(m\leq t\leq 10^9\)
Solution
记 \(L=\min\{t, n+m-1\}\)。
\(p_{i,j}\):限制序列长度为 \(i\),第一个位置出现 \(j\) 的概率。
\[p_{i,j}=[j\leq m]\frac1m+p_{i,j-1}\times p_{i-1,j-1}
\]
\(q_{i,j}\):限制序列长度为 \(i\),第一个位置为 \(j\) 下,之后不再改变的概率。
\[q_{i,j}=1-[j<L]p_{i-1,j}
\]
\(g_{i,j}\):限制序列长度为 \(i\),第一个位置为 \(j\),且之后不再改变后,整个序列的期望。
\(ans_i\):序列长度为 \(i\) 时,权值和的期望。
\(f_{i,j}\):序列长度为 \(i\),第一个数字出现 \(j\) 时,序列元素权值和。
\(j<L\) 时,
\[\begin{aligned}
g_{i,j}&=j+\sum_{S}S\times\Pr\{第2到i权值和为S|第1个为j,且j不变\}\\
&=j+\frac{\sum_{S}S\times\Pr\{第2到i权值和为S,j不改变|第1个为j\}}{\Pr\{j不改变|第1个为j\}}\\
&=j+\frac{\sum_{S}S\times\Pr\{第2到i权值和为S,j不改变|第1个为j\}}{q_{i,j}}\\
&=j+\frac{\sum_{S}S\times\Pr\{第2到i权值和为S,j任意|第1个为j\}-\sum_{S}S\times\Pr\{第2到i权值和为S,j改变|第1个为j\}}{q_{i,j}}\\
&=j+\frac{ans_{i-1}-\Pr\{第2个为j\}\times\sum_S S\times\Pr\{第2到i权值和为S|第2个为j\}}{q_{i,j}}\\
&=j+\frac{ans_{i-1}-p_{i-1,j}\times f_{i-1,j}}{q_{i,j}}
\end{aligned}
\]
\(j=L\) 时,
\[g_{i,j}=j+ans_{i-1}
\]
对于 \(ans_i\):
\[ans_i=\sum_{j=1}^Lp_{i,j}\times q_{i,j}\times g_{i,j}
\]
对于 \(f_{i,j}\):
\[f_{i,j}=q_{i,j}\times g_{i,j}+(1-q_{i,j})\times f_{i,j+1}
\]
为了避免求逆元,将 \(g_{i,j}\times q_{i,j}\) 整体转移。复杂度 \(\mathcal O(nm)\)。
Code
#include <bits/stdc++.h>
const int N = 2005, P = 1e9 + 7;
using std::min;
int n, m, t, L, f[N][N], qg[N][N], p[N][N], q[N][N], ans[N];
int qpow(int a, int b) {
int t = 1;
for (; b; b >>= 1, a = 1LL * a * a % P)
if (b & 1) t = 1LL * t * a % P;
return t;
}
int main() {
scanf("%d%d%d", &n, &m, &t); L = min(n + m - 1, t);
int inv = qpow(m, P-2);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= L; j++) {
p[i][j] = ((j <= m ? inv : 0) + 1LL * p[i][j - 1] * p[i - 1][j - 1]) % P;
q[i][j] = (1 - (j < L ? p[i - 1][j] : 0) + P) % P;
}
for (int j = L; j; j--) {
qg[i][j] = (1LL * q[i][j] * j + ans[i - 1] - (j < L ? 1LL * p[i - 1][j] * f[i - 1][j] % P : 0) + P) % P;
f[i][j] = (qg[i][j] + (j < L ? 1LL * (1 - q[i][j] + P) * f[i][j + 1] : 0)) % P;
ans[i] = (ans[i] + 1LL * p[i][j] * qg[i][j]) % P;
}
}
printf("%d\n", ans[n]);
return 0;
}