[SDOI2012] 走迷宫 题解
前言
题目链接:洛谷;Hydro & bzoj。
题意简述
有向图中,求起点到终点的期望步数。若期望不存在,输出 INF
。
保证强连通分量的大小不超过 \(100\)。
题目分析
首先来想想什么情况下期望不存在。很显然是,从起点能走到一个点,而该点永远走不到终点,当然,是在走到终点马上停下的前提下。转化一下,就是从起点开始 BFS,如果遇到一个点,没有出度,并且这个点还不是终点,那么这种情况下应该输出 INF
。但是,很多人都想错了这一步判 INF
,详见此帖。
代码实现起来很简单:
bool vis[10010];
bool check() {
queue<int> Q;
Q.push(S), vis[S] = true;
bool can = false;
while (!Q.empty()) {
int now = Q.front(); Q.pop();
if (now == T) {
can = true;
continue;
}
if (!xym.head[now]) return false;
for (int i = xym.head[now], to; to = xym[i].to, i; i = xym[i].nxt) {
if (vis[to]) continue;
vis[to] = true;
Q.push(to);
}
}
return can;
}
接下来考虑如何求期望步数。
一个套路的想法,记 \(f_i\) 为从 \(i\) 到终点的期望步数,边界 \(f_t = 0\),答案就是 \(f_s\)。转移就是在出边里等概率选择一条边。
由于存在环形转移,所以使用高斯消元解方程组就行了。注意到,在预处理增广矩阵的时候,对于 \(xym \to yzh\) 这条边,如果 \(xym = t\),就不做处理;如果在之前 check
的时候没走到过 \(xym\),即 \(\operatorname{vis}[xym] = \text{false}\),也不要添加到矩阵里。
由于时间复杂度 \(\Theta(n^3)\),能拿到 \(70\) 分。考虑如何优化。
注意到之所以要用高斯消元,是因为存在环形转移。如果是在序列上,或者换句话说,在一个 DAG 上,我们直接 DP 就行了。所以考虑用 tarjan 缩点,强联通分量里高斯消元,分量外拓扑排序直接期望 DP。这么做正确性体现在题目中保证强连通分量的大小不超过 \(100\),时间复杂度 \(\Theta(\sum siz^3) \leq \mathcal{O}(n\max ^ 2siz)\)。
具体地,我们反向跑拓扑。对于当前强联通分量里的每一个点连出的边,如果对方不是同一个强联通分量,则已经被我们计算过了,加到右边常数里;反之处理到左边系数矩阵里。
代码
挺快的,卡卡常最优解。
// #pragma GCC optimize(3)
// #pragma GCC optimize("Ofast", "inline", "-ffast-math")
// #pragma GCC target("avx", "sse2", "sse3", "sse4", "mmx")
#include <iostream>
#include <cstdio>
#define debug(a) cerr << "Line: " << __LINE__ << " " << #a << endl
#define print(a) cerr << #a << "=" << (a) << endl
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)
#define main Main(); signed main() { return ios::sync_with_stdio(0), cin.tie(0), Main(); } signed Main
using namespace std;
#include <algorithm>
#include <queue>
#include <vector>
#include <cstring>
template <size_t N, size_t M>
int guass(int, int, double [M][N], double [M], double [N]);
const double eps = 1e-10;
int n, m, S, T;
int U[1000010], V[1000010];
int du[10010];
struct Graph{
struct node{
int to, nxt;
} edge[1000010 << 1];
int eid, head[10010];
inline void add(int u, int v){
edge[++eid] = {v, head[u]};
head[u] = eid;
}
inline node & operator [] (const int x){
return edge[x];
}
} xym, yzh;
bool vis[10010];
bool check() {
queue<int> Q;
Q.push(S), vis[S] = true;
bool can = false;
while (!Q.empty()) {
int now = Q.front(); Q.pop();
if (now == T) {
can = true;
continue;
}
if (!xym.head[now]) return false;
for (int i = xym.head[now], to; to = xym[i].to, i; i = xym[i].nxt) {
if (vis[to]) continue;
vis[to] = true;
Q.push(to);
}
}
return can;
}
int dfn[10010], low[10010], timer;
int sccno[10010], scc_cnt;
int stack[10010], top;
bool in_stack[10010];
vector<int> scc[10010];
int whr[10010];
double key[110][110], val[110], res[10010][110];
void tarjan(int now) {
dfn[now] = low[now] = ++timer, in_stack[stack[++top] = now] = true;
for (int i = xym.head[now]; i; i = xym[i].nxt) {
int to = xym[i].to;
if (dfn[to] == 0) tarjan(to), low[now] = min(low[now], low[to]);
else if (in_stack[to]) low[now] = min(low[now], dfn[to]);
}
if (low[now] == dfn[now]){
++scc_cnt;
do {
int now = stack[top--];
in_stack[now] = false;
sccno[now] = scc_cnt;
scc[scc_cnt].push_back(now);
whr[now] = scc[scc_cnt].size();
} while (stack[top + 1] != now);
}
}
signed main() {
scanf("%d%d%d%d", &n, &m, &S, &T);
for (int i = 1, u, v; i <= m; ++i) {
scanf("%d%d", &u, &v);
xym.add(u, v);
yzh.add(v, u);
++du[u];
U[i] = u, V[i] = v;
}
if (!check()) return puts("INF"), 0;
tarjan(S);
for (int i = 1; i <= scc_cnt; ++i) {
int siz = scc[i].size();
for (int j = 1; j <= siz; ++j) {
memset(key[j], 0x00, sizeof (double) * (siz + 1));
key[j][j] = 1;
val[j] = scc[i][j - 1] != T;
}
for (const auto& u: scc[i]) if (u != T)
for (int _ = xym.head[u], v; v = xym[_].to, _; _ = xym[_].nxt) {
if (sccno[u] == sccno[v]) {
key[whr[u]][whr[v]] -= 1.0 / du[u];
} else {
val[whr[u]] += 1.0 / du[u] * res[sccno[v]][whr[v]];
}
}
guass<110, 110>(siz, siz, key, val, res[i]);
}
printf("%.3lf", res[sccno[S]][whr[S]]);
return 0;
}
template <size_t N, size_t M>
int guass(int n, int m, double key[M][N], double val[M], double res[N]) {
if (m < n) return -1;
int nline = 1;
for (int i = 1; i <= n; ++i) {
int whr = nline;
for (int j = nline + 1; j <= m; ++j)
if (abs(key[j][i]) > abs(key[whr][i]))
whr = j;
if (abs(key[whr][i]) < eps) continue;
swap(val[nline], val[whr]);
for (int j = 1; j <= n; ++j) swap(key[nline][j], key[whr][j]);
for (int j = 1; j <= m; ++j) if (j != nline) {
double K = key[j][i] / key[nline][i];
val[j] -= K * val[nline];
for (int k = i; k <= n; ++k)
key[j][k] -= key[nline][k] * K;
}
++nline;
}
if (nline == n + 1) {
for (int i = 1; i <= n; ++i)
res[i] = val[i] / key[i][i] + eps;
return 0;
}
for (int i = nline; i <= m; ++i)
if (abs(val[i]) > eps)
return -2;
for (int i = 1; i <= nline; ++i)
res[i] = val[i] / key[i][i] + eps;
return -1;
}
本文作者:XuYueming,转载请注明原文链接:https://www.cnblogs.com/XuYueming/p/18317591。
若未作特殊说明,本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可。