HDU7016 Random Walk 2
传送门
这题真可惜,比赛的时候最后10分钟推出来了,但是10分钟写不完高斯消元啊。
这题思路还挺多,有题解的一种,还有我们队里大佬的另一种,以及我自己的一种。但是其他两个思路我都不是很懂,遂只能将自己的思路记录于此了。
这题我就是按照图上的随机游走模型去做的。
我们先把起点选择在\(1\)号点。令\(E(u)\)表示走到\(u\),且能继续走下去的概率,\(f(u)\)表示走到\(u\),且停在\(u\)的概率。能继续走下去就意味着上一时刻不能在\(u\),停住就意味着上一时刻必须在\(u\).
那么可以列出关系式:
\(F(u)=E(u) * P(u,u) + P(1,1)*(u == 1)\),
\(E(u)=\sum\limits_{v = 1,v \neq u} ^ {n} E(v) * P(v, u) + P(1,u) * (u \neq 1)\).
注意后面的常数项,因为刚开始在\(1\)号点,所以有\(P(1,1)\)的概率停住,有\(P(1,u)\)的概率走到别的点。(就因为这些常数项我没整明白,推不对又放弃了)
\(E(u)\)之间的关系可以用高斯消元来解,代入就能求得\(F(u)\)了。
上面是起点在\(1\)的情况,那么对于起点在\(x\)的情况,我们发现系数矩阵是一样的,只有增广矩阵的最右边一列是不一样的,因此可以将增广矩阵写成\(n+n\)列,\(n\)个方程组同时解。那么时间复杂度就是\(O(n * n * 2n)\).
当然,也可以把矩阵方程列出来然后求逆,效果一样。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<queue>
#include<assert.h>
#include<ctime>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const ll mod = 998244353;
const int maxn = 302;
In ll read()
{
ll ans = 0;
char ch = getchar(), las = ' ';
while(!isdigit(ch)) las = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(las == '-') ans = -ans;
return ans;
}
In void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
In ll ADD(ll a, ll b) {return a + b < mod ? a + b : a + b - mod;}
In ll quickpow(ll a, ll b)
{
ll ret = 1;
for(; b; b >>= 1, a = a * a % mod)
if(b & 1) ret = ret * a % mod;
return ret;
}
int n, m;
ll p[maxn][maxn];
ll f[maxn][maxn << 1];
In void Gauss()
{
for(int i = 1; i <= n; ++i)
{
int pos = i;
while(pos <= n && !f[pos][i]) ++pos;
if(pos > n) continue;
if(pos > i) swap(f[i], f[pos]);
ll inv = quickpow(f[i][i], mod - 2);
for(int j = i; j <= m; ++j) f[i][j] = f[i][j] * inv % mod;
for(int j = i + 1; j <= n; ++j)
{
ll tp = f[j][i];
for(int k = i; k <= m; ++k) f[j][k] = ADD(f[j][k], mod - tp * f[i][k] % mod);
}
}
for(int i = n; i; --i)
for(int j = i - 1; j; --j)
for(int k = n + 1; k <= m; ++k)
f[j][k] = ADD(f[j][k], mod - f[j][i] * f[i][k] % mod);
}
In void solve()
{
Mem(f, 0); m = n + n;
for(int i = 1; i <= n; ++i)
{
ll sum = 0;
for(int j = 1; j <= n; ++j) sum += p[i][j];
sum = quickpow(sum, mod - 2);
for(int j = 1; j <= n; ++j) p[i][j] = p[i][j] * sum % mod;
}
for(int i = 1; i <= n; ++i)
{
f[i][i] = 1;
for(int j = 1; j <= n; ++j)
if(i ^ j)
{
f[i][j] = mod - p[j][i];
f[j][n + i] = ADD(f[j][n + i], p[i][j]);
}
}
Gauss();
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= n; ++j)
{
write(ADD(f[j][n + i] * p[j][j] % mod, (i == j) * p[j][j]));
j == n ? enter : space;
}
}
int main()
{
int T = read();
while(T--)
{
n = read();
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= n; ++j) p[i][j] = read();
solve();
}
return 0;
}