[lnsyoj4079/luoguP6899]Pachinko
题意
一个包括空地、障碍和空洞的 \(H\times W\) 地图,从第一列随机选取空地作为起始位置,到达某个空洞后停止运动。给定向上下左右移动的概率比 \(p_u:p_d:p_l:p_r\),求在每个空洞停止运动的概率为多少。
sol
由于每次到达空洞后即会停止运动,因此到达空洞的期望次数即为到达每个空洞的概率。记 \(f_{i,j}\) 表示到达 \((i,j)\) 的次数的期望,就可以很容易的推出转移方程:
(\(stcnt\) 为第一行空地的数量,\(psum\) 为可以转移到的方向的概率比之和)
需要注意一点:当出现某一列全为 \(0\) 无法消元时,对答案并无影响,因此 continue
即可。
这样,我们只需要对列出的方程进行高斯消元即可解出答案……吗?
我们发现,这样一来,我们所求的未知数数量为 \(HW\) 个,而高斯消元的复杂度是 \(O(n^3)\) 的,因此时间复杂度就达到了 \(O((HW)^3)\),空间复杂度达到了 \(O((HW)^2)\),无法接受。
不过,我们的增广矩阵中,最多只会在对角线左右 \(\pm W\) 的位置出现非零数字,这是一个非常经典的优化 band-matrix
。在消元和回代时,只向后修改 \(W\) 列即可,即使是矩阵中出现 \(0\),需要换行时,也最多只需修改 \((2W)\times (2W)\) 的矩阵即可,而不需要修改整个矩阵;而且,我们可以为每一个可能出现非零数字的位置记录一个偏移量,这样时间复杂度就优化到了 \(O(HW^3)\),空间复杂度优化到了 \(O(HW^2)\)。
注意:如果 \(\color{red}{\text{WA}}\) on #9, Line 13,说明是被卡精度了;如果结果输出了 -nan
,请改变写法,不要直接使用 \(i*m+j\) 作为编号存储,而是对每一个非障碍方格进行编号。
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
const int N = 25, M = 10005;
const long double eps = 1e-10;
int n, m, p[4];
int dx[4] = {-1, 1, 0, 0}, dy[4] = {0, 0, -1, 1};
int pfrom[4] = {1, 0, 3, 2};
long double coeff[N * M][N * 4];
long double ans[N * M];
int psum[N * M];
int id[M][N];
char g[M][N];
int idx;
long double &c(int a, int b){
return coeff[a][b - max(0, a - 2 * m)];
}
void gauss(){
for (int i = 1; i <= idx; i ++ ){
if (abs(c(i, i)) <= eps) continue;
int ed = min(idx, i + 2 * m);
ans[i] /= c(i, i);
for (int j = ed; j >= i; j -- )
c(i, j) /= c(i, i);
for (int j = i + 1; j <= min(i + m, idx); j ++ ){
ans[j] -= c(j, i) * ans[i];
for (int k = ed; k >= i; k -- ) c(j, k) -= c(j, i) * c(i, k);
}
}
for (int i = idx; i; i -- ){
for (int j = i + 1; j <= min(idx, i + 2 * m); j ++ )
ans[i] -= c(i, j) * ans[j];
}
}
int main(){
scanf("%d%d", &m, &n);
for (int i = 0; i < 4; i ++ ) scanf("%d", &p[i]);
for (int i = 0; i < n; i ++ ) scanf("%s", g[i]);
int stcnt = 0;
for (int i = 0; i < m; i ++ ) stcnt += (g[0][i] == '.');
for (int i = 0; i < n; i ++ )
for (int j = 0; j < m; j ++ ){
id[i][j] = (g[i][j] == 'X') ? -1 : ++ idx;
if (id[i][j] == -1) continue;
for (int u = 0; u < 4; u ++ ) {
int sx = i + dx[u], sy = j + dy[u];
if (sx < 0 || sx >= n || sy < 0 || sy >= m || g[sx][sy] == 'X') continue;
psum[id[i][j]] += p[u];
}
}
for (int i = 0; i < n; i ++ )
for (int j = 0; j < m; j ++ ){
if (id[i][j] == -1) continue;
if (!i && g[i][j] == '.') ans[id[i][j]] = (long double) 1 / stcnt;
for (int u = 0; u < 4; u ++ ){
int sx = i + dx[u], sy = j + dy[u];
int px = pfrom[u];
if (sx < 0 || sx >= n || sy < 0 || sy >= m || g[sx][sy] == 'X' || g[sx][sy] == 'T') continue;
int sid = id[sx][sy];
c(id[i][j], sid) = (long double) -p[px] / psum[sid];
}
c(id[i][j], id[i][j]) = 1;
}
gauss();
for (int i = 0; i < n; i ++ )
for (int j = 0; j < m; j ++ )
if (g[i][j] == 'T') printf("%.9Lf\n", ans[id[i][j]]);
return 0;
}