AtCoder Beginner Contest 302 Ex Ball Collector
考虑如果只询问一次怎么做。连边 \((a_i, b_i)\),对于每个连通块分别考虑。这是 ARC111B,如果一个连通块是树,肯定有一个点不能被选;否则有环,一定能构造一种方案,使得每个点都被选。
那么现在对于每个点都要求,考虑 dfs 时合并当前的 \((a_u, b_u)\),并且使用可撤销并查集。具体而言,把每次的修改都压进栈里,退出一个点就把这些修改全部复原。注意不要路径压缩,使用按秩合并。
code
// Problem: Ex - Ball Collector
// Contest: AtCoder - TOYOTA MOTOR CORPORATION Programming Contest 2023#2 (AtCoder Beginner Contest 302)
// URL: https://atcoder.jp/contests/abc302/tasks/abc302_h
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 200100;
int n, a[maxn], b[maxn], fa[maxn], rnk[maxn], sz[maxn], e[maxn], top, ans, c[maxn];
pair<int*, int> stk[maxn * 50];
vector<int> G[maxn];
int find(int x) {
return fa[x] == x ? x : find(fa[x]);
}
inline void merge(int x, int y) {
x = find(x);
y = find(y);
stk[++top] = make_pair(&ans, ans);
if (x == y) {
ans -= (e[x] == sz[x] - 1 ? sz[x] - 1 : sz[x]);
stk[++top] = make_pair(e + x, e[x]);
++e[x];
ans += (e[x] == sz[x] - 1 ? sz[x] - 1 : sz[x]);
return;
}
ans -= (e[x] == sz[x] - 1 ? sz[x] - 1 : sz[x]);
ans -= (e[y] == sz[y] - 1 ? sz[y] - 1 : sz[y]);
if (rnk[x] <= rnk[y]) {
stk[++top] = make_pair(fa + x, fa[x]);
fa[x] = y;
stk[++top] = make_pair(sz + y, sz[y]);
sz[y] += sz[x];
stk[++top] = make_pair(e + y, e[y]);
e[y] += e[x] + 1;
ans += (e[y] == sz[y] - 1 ? sz[y] - 1 : sz[y]);
if (rnk[x] == rnk[y]) {
stk[++top] = make_pair(rnk + y, rnk[y]);
++rnk[y];
}
} else {
stk[++top] = make_pair(fa + y, fa[y]);
fa[y] = x;
stk[++top] = make_pair(sz + x, sz[x]);
sz[x] += sz[y];
stk[++top] = make_pair(e + x, e[x]);
e[x] += e[y] + 1;
ans += (e[x] == sz[x] - 1 ? sz[x] - 1 : sz[x]);
}
}
inline void undo() {
*stk[top].fst = stk[top].scd;
--top;
}
void dfs(int u, int fa) {
int lsttop = top;
merge(a[u], b[u]);
c[u] = ans;
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs(v, u);
}
while (top > lsttop) {
undo();
}
}
void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d%d", &a[i], &b[i]);
fa[i] = i;
rnk[i] = sz[i] = 1;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
dfs(1, -1);
for (int i = 2; i <= n; ++i) {
printf("%d ", c[i]);
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}