Endless Walk(拓扑排序、强连通分量)
题意
给定一张简单有向图,其中点的个数为\(n\),边的个数为\(m\)。
问有多少个点满足如下要求:从该点出发,能够永不停止地走下去。
数据范围
\(1 \leq n \leq 2 \times 10^5\)
\(0 \leq m \leq 2 \times 10^5\)
思路
这道题是个强连通分量的模板题,但是官方题解的做法更加简洁、更加高效,是非常值得学习。
不过无论是那种方法,首先要做的就是将永不停止转化为:不会走到尽头,也就是会走到环。
问题的答案就是环以及之前的所有点的个数。
-
方法一:强连通分量
建反图,然后求出所有的强连通分量。将所有点数不少于\(2\)的强连通分量中的点加入队列中,跑BFS。经过点的个数即为答案。 -
方法二:拓扑排序
建反图,跑拓扑排序,当没有入度为\(0\)的点时就停下来。\(n -\) 所有曾经出现在队列中的点数即为答案。
代码
- 强连通分量:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 200010, M = 200010;
int n, m;
int h[N], e[M], ne[M], idx;
int dfn[N], low[N], timestamp;
int stk[N], top;
bool in_stk[N];
int id[N], scc_cnt, sz[N];
int d[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void tarjan(int u)
{
dfn[u] = low[u] = ++ timestamp;
stk[ ++ top] = u, in_stk[u] = true;
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (!dfn[j])
{
tarjan(j);
low[u] = min(low[u], low[j]);
}
else if (in_stk[j]) low[u] = min(low[u], dfn[j]);
}
if (dfn[u] == low[u])
{
++ scc_cnt;
int y;
do {
y = stk[top -- ];
in_stk[y] = false;
id[y] = scc_cnt;
sz[scc_cnt] ++ ;
} while (y != u);
}
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
while (m -- )
{
int a, b;
scanf("%d%d", &a, &b);
add(b, a);
}
for (int i = 1; i <= n; i ++ )
if (!dfn[i])
tarjan(i);
memset(d, 0x3f, sizeof d);
queue<int> que;
for(int i = 1; i <= n; i ++) {
if(sz[id[i]] > 1) {
que.push(i);
d[i] = 0;
}
}
while(que.size()) {
int t = que.front();
que.pop();
for(int i = h[t]; ~i; i = ne[i]) {
int j = e[i];
if(d[j] > d[t] + 1) {
d[j] = d[t] + 1;
que.push(j);
}
}
}
int ans = 0;
for(int i = 1; i <= n; i ++) {
if(d[i] != 0x3f3f3f3f) ans ++;
}
printf("%d\n", ans);
return 0;
}
- 拓扑排序:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 200010, M = N;
int n, m;
int h[N], e[M], ne[M], idx;
int din[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for(int i = 0; i < m; i ++) {
int a, b;
scanf("%d%d", &a, &b);
add(b, a);
din[a] ++;
}
queue<int> que;
for(int i = 1; i <= n; i ++) {
if(!din[i]) {
que.push(i);
}
}
int num = 0;
while(que.size()) {
num ++;
int t = que.front();
que.pop();
for(int i = h[t]; ~i; i = ne[i]) {
int j = e[i];
if(--din[j] == 0) {
que.push(j);
}
}
}
int ans = n - num;
printf("%d\n", ans);
return 0;
}