[BZOJ2125]最短路[圆方树]
题意
给定仙人掌,多次询问两点之间的最短路径。
\(n\le 10000, Q\le 10000\)
分析
- 建出圆方树,分路径 lca 是圆点还是方点讨论。
- 预处理出根圆点到每个圆点的最短距离 \(dis\) 。
- 如果 lca 是圆点,那么最短距离就是 \(dis_a+dis_b-2*dis_{lca}\)。
- 否则找到 lca 到 a, b 路径上的第一个圆点 x, y,最短距离即 \(dis_a-dis_x+dis_b-dis_y+dist(x, y)\) 。其中 \(dist(x, y)\) 表示在同一个环中的两个节点 \(x, y\) 之间的最短距离。
- 复杂度 \(O(nlogn+Qlogn)\) 。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
#define go(u) for(int i = head[u], v = e[i].to; i; i=e[i].lst, v=e[i].to)
#define rep(i, a, b) for(int i = a; i <= b; ++i)
#define pb push_back
#define re(x) memset(x, 0, sizeof x)
inline int gi() {
int x = 0,f = 1;
char ch = getchar();
while(!isdigit(ch)) { if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) { x = (x << 3) + (x << 1) + ch - 48; ch = getchar();}
return x * f;
}
template <typename T> inline bool Max(T &a, T b){return a < b ? a = b, 1 : 0;}
template <typename T> inline bool Min(T &a, T b){return a > b ? a = b, 1 : 0;}
const int N = 2e4 + 7;
int n, m, edc, dfn, tp, ndc, ans, Q;
int low[N], pre[N], stk[N], f[N][2], head[N];
int up[N][18], s[N], dis[N], tot[N], dep[N];
vector<int>G[N];
map<pair<int, int>, int> dist;
struct edge {
int lst, to;
edge(){}edge(int lst, int to):lst(lst), to(to){}
}e[N << 1];
void Add(int a, int b) {
e[++edc] = edge(head[a], b), head[a] = edc;
e[++edc] = edge(head[b], a), head[b] = edc;
}
void tarjan(int u, int fa) {
low[u] = pre[u] = ++dfn;
stk[++tp] = u;
go(u)if(v ^ fa) {
if(!low[v]) {
tarjan(v, u);
Min(pre[u], pre[v]);
if(pre[v] >= low[u]) {
G[u].pb(++ndc);
for(int x = -1; x ^ v; )
G[ndc].pb(x = stk[tp--]);
}
}else Min(pre[u], low[v]);
}
}
#define mp make_pair
void dfs(int u, int fa) {
up[u][0] = fa;
for(int i = 1; i <= 17; ++i) up[u][i] = up[up[u][i - 1]][i - 1];
if(u > n) {
tot[u] += dist[mp(fa, G[u][0])];
s[G[u][0]] = dist[mp(fa, G[u][0])];
for(int i = 1; i < G[u].size(); ++i) {
s[G[u][i]] = s[G[u][i - 1]] + dist[mp(G[u][i - 1], G[u][i])];
tot[u] += dist[mp(G[u][i - 1], G[u][i])];
}
tot[u] += dist[mp(fa, G[u][G[u].size() - 1])];
}
for(int i = 0; i < G[u].size(); ++i) {
int v = G[u][i];
dep[v] = dep[u] + 1;
if(u > n) dis[v] = dis[fa] + min(s[v], tot[u] - s[v]);
dfs(v, u);
}
}
int get(int a, int b) {
if(a == b) return 0;
if(dep[a] > dep[b]) swap(a, b);
int x = a, y = b;
for(int i = 17; ~i; --i) if(dep[up[b][i]] >= dep[a]) b = up[b][i];
if(a == b) {
b = y;
for(int i = 17; ~i; --i) if(dep[up[b][i]] > dep[a]) b = up[b][i];
if(a <= n) return dis[y] - dis[a];
return dis[y] - dis[b];
}
for(int i = 17; ~i; --i) if(up[a][i] != up[b][i]){
a = up[a][i];
b = up[b][i];
}
int lca = up[a][0];
if(lca <= n) {
return dis[x] + dis[y] - 2 * dis[lca];
}else {
return dis[x] - dis[a] + dis[y] - dis[b] + min(abs(s[b] - s[a]), tot[lca] - abs(s[b] - s[a]));
}
}
int main() {
n = gi(), m = gi();Q = gi();ndc = n;
rep(i, 1, m) {
int a = gi(), b = gi(), c = gi();
Add(a, b);
dist[make_pair(a, b)] = dist[make_pair(b, a)] = c;
}
tarjan(1, 0);
dep[1] = 1, dfs(1, 0);
while(Q--) {
int a = gi(), b = gi();
printf("%d\n", get(a, b));
}
return 0;
}