「CodePlus 2017 11 月赛」大吉大利,晚上吃鸡!
思路
如果只要求一个点,那么就是求最短路上的必经点,有很多方法可以实现(比如某波说的建 DAG 跑 tarjan 求割点)
但现在是两个点,我们考虑将问题转化
要所有最短路都经过 \(A\) 或 \(B\),那么就是说,经过 \(A\) 的最短路方案数,加上经过 \(B\) 的最短路方案数,要等于所有最短路的方案数
(为什么不用减去同时经过 \(A,B\) 的最短路方案数,是因为第二个条件的限制)
而要不存在任意一条最短路同时经过 \(A,B\),那么就是在最短路的 DAG 中,不存在 \(A\) 到 \(B\) 的路径
对于第一个限制,我们可以跑经过点 \(u\) 的最短路方案数 \(f_u\),然后在 \(f_u\) 的 bitset 中,将 \(u\) 的位置设为 \(1\)(用 map 辅助实现)
对于第二个限制,我们分别从 \(S,T\) 出发,跑两次传递闭包,算出每个点 \(u\) 能到达的点,并在 \(u\) 的 bitset 中标记为 \(1\),然后将它取反
最后统计答案时,将两个 bitset 取并,计算个数
最后一定要记得,如果 \(S,T\) 不连通,那么要输出 \(\frac{n(n-1)}{2}\)(虽然我觉得这很不符合题意)
代码
#include<iostream>
#include<fstream>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<queue>
#include<unordered_map>
#include<set>
#include<bitset>
#define LL long long
#define FOR(i, x, y) for(int i = (x); i <= (y); i++)
#define ROF(i, x, y) for(int i = (x); i >= (y); i--)
#define PFOR(i, x) for(int i = he[x]; i; i = r[i].nxt)
inline int rd()
{
int sign = 1, re = 0; char c = getchar();
while(c < '0' || c > '9'){if(c == '-') sign = -1; c = getchar();}
while('0' <= c && c <= '9'){re = re * 10 + (c - '0'); c = getchar();}
return sign * re;
}
const LL INF = 1e15;
int n, m, S, T;
struct Node
{
int to, nxt, w;
}r[100005]; int he[50005];
inline void Edge_add(int u, int v, int w)
{
static int cnt = 0;
r[++cnt] = (Node){v, he[u], w};
he[u] = cnt;
}
LL dis[50005];
#define pli std::pair<LL, int>
std::priority_queue<pli, std::vector<pli>, std::greater<pli>> q;
std::bitset<50005> vis, sn;
std::vector<int> e[50005][2];
inline void dij()
{
FOR(i, 1, n) dis[i] = INF;
dis[S] = 0; q.push(pli(0, S));
while(!q.empty())
{
while(!q.empty() && vis[q.top().second]) q.pop();
if(q.empty()) break;
int now = q.top().second; q.pop();
vis[now] = 1;
PFOR(i, now)
{
int to = r[i].to;
if(!vis[to] && dis[to] > dis[now] + r[i].w)
{
dis[to] = dis[now] + r[i].w;
e[to][0].clear(), e[to][0].emplace_back(now);
q.push(pli(dis[to], to));
}
else if(dis[to] == dis[now] + r[i].w)
e[to][0].emplace_back(now);
}
}
}
std::queue<int> qu;
inline void bfs()
{
qu.push(T); sn[T] = 1;
while(!qu.empty())
{
int now = qu.front(); qu.pop();
for(int to : e[now][0])
if(!sn[to])
{
sn[to] = 1;
qu.push(to);
}
}
}
LL f[50005][2];
std::unordered_map<LL, std::bitset<50005>> mp;
std::bitset<50005> ar[50005][2];
inline void dfs(int now, int ed, int t)
{
vis[now] = 1, ar[now][t][now] = 1;
if(now == ed) return void(f[now][t] = 1);
for(int to : e[now][t])
{
if(!vis[to]) dfs(to, ed, t);
f[now][t] += f[to][t];
ar[now][t] |= ar[to][t];
}
}
LL ans;
signed main()
{
#ifndef ONLINE_JUDGE
freopen("test.in", "r", stdin);
freopen("test.out", "w", stdout);
#endif
n = rd(), m = rd(), S = rd(), T = rd();
FOR(i, 1, m)
{
int u = rd(), v = rd(), w = rd();
if(u == v) continue;
Edge_add(u, v, w), Edge_add(v, u, w);
}
dij();
if(dis[T] == INF)
{
printf("%lld", 1ll * n * (n - 1) / 2);
return 0;
}
bfs();
FOR(u, 1, n)
{
if(!sn[u]) {e[u][0].clear(); continue;}
for(int v : e[u][0]) e[v][1].emplace_back(u);
}
vis = 0, dfs(T, S, 0);
vis = 0, dfs(S, T, 1);
FOR(i, 1, n)
f[i][0] *= f[i][1],
mp[f[i][0]][i] = 1,
ar[i][0] |= ar[i][1], ar[i][0].flip();
FOR(i, 1, n)
ans += (mp[f[T][0] - f[i][0]] & ar[i][0]).count();
printf("%lld", ans >> 1);
return 0;
}