bzoj3910 火车
Description
A 国有n 个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一
路径。现在有火车在城市 a,需要经过m 个城市。火车按照以下规则行驶:每次
行驶到还没有经过的城市中在 m 个城市中最靠前的。现在小 A 想知道火车经过
这m 个城市后所经过的道路数量。
Input
第一行三个整数 n、m、a,表示城市数量、需要经过的城市数量,火车开始
时所在位置。
接下来 n-1 行,每行两个整数 x和y,表示 x 和y之间有一条双向道路。
接下来一行 m 个整数,表示需要经过的城市。
Output
一行一个整数,表示火车经过的道路数量。
Sample Input
5 4 2
1 2
2 3
3 4
4 5
4 3 1 5
Sample Output
9
Hint
N<=500000 ,M<=400000
\(Lca\) + 并查集
一开始想到用\(lca\)求走过的边数了,但是将点标记的时候我用的是暴力向上跳,结果\(Tle\)了。
看了题解发现可以用并查集维护这个点是否被走过了。具体看代码。
#include <iostream>
#include <cstdio>
#include <cctype>
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 5e5 + 5, M = 4e5 + 5;
int n, m, x, cnt;
long long ans;
int a[M], f[N], fa[N][21], dep[N], head[N];
struct edge { int to, nxt; } e[N << 1];
void add(int x, int y) {
e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y;
}
int find(int x) {
return x == f[x] ? x : f[x] = find(f[x]);
}
void get_dep(int x, int Fa) {
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == Fa) continue;
dep[y] = dep[x] + 1; fa[y][0] = x;
get_dep(y, x);
}
}
void make_fa() {
for(int i = 1;i <= 20; i++) {
for(int j = 1;j <= n; j++) {
fa[j][i] = fa[fa[j][i - 1]][i - 1];
}
}
}
int LCA(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20;i >= 0; i--) {
if(dep[x] - dep[y] >= (1 << i)) x = fa[x][i];
}
if(x == y) return x;
for(int i = 20;i >= 0; i--) {
if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
}
return fa[x][0];
}
void init() {
n = read(); m = read(); x = read();
for(int i = 1;i <= n; i++) f[i] = i;
for(int i = 1, x, y;i <= n - 1; i++) {
x = read(); y = read();
add(x, y); add(y, x);
}
for(int i = 1;i <= m; i++) a[i] = read();
}
void work() {
get_dep(x, 0); make_fa();
for(int i = 1;i <= m; i++) {
int tx = find(x), ty = find(a[i]);
if(tx == ty) continue;
int lca = LCA(x, a[i]);
ans += dep[x] + dep[a[i]] - (dep[lca] << 1);
lca = find(lca);
int tmp = tx;
while(find(tmp) != lca) {
int father = find(tmp);
f[father] = lca; tmp = fa[father][0];
}
tmp = ty;
while(find(tmp) != lca) {
int father = find(tmp);
f[father] = lca; tmp = fa[father][0];
}
x = a[i];
}
printf("%lld", ans);
}
int main() {
init();
work();
return 0;
}