【POJ3417】闇の連鎖
题目链接:https://www.acwing.com/problem/content/354/
题目大意:给定一张 \(n\) 个点的无向图 , 其中主要边为 \(n - 1\) 条连接这 \(n\) 个点的边 , 附加边为 \(m\) 条任意边 , 求砍断一条主要边与附加边就能把这张图分为两个不联通的子图的方案数
solution
令 \(f_i\) 表示砍断这第 \(i\) 条主要边后 , 还要砍断的附加边的数量 , 对于一条连接 \(a_i\) 与 \(b_i\) 的附加边 , 则应令从 \(a_i\) 到 \(b_i\) 路径上的每一条边的 \(f\) 加 1 , 这可以用树上差分来实现 , 最后统计时 , 如果 \(f_i = 1\) , 那么答案就会增加 1 , 如果 \(f_i = 0\) , 那么答案应该增加 \(m\) (因为第二次可以选择任意一条附加边砍断)
时间复杂度 : \(O(nlogn)\)
code
#include<bits/stdc++.h>
using namespace std;
template <typename T> inline void read(T &FF) {
int RR = 1; FF = 0; char CH = getchar();
for(; !isdigit(CH); CH = getchar()) if(CH == '-') RR = -RR;
for(; isdigit(CH); CH = getchar()) FF = FF * 10 + CH - 48;
FF *= RR;
}
inline void file(string str) {
freopen((str + ".in").c_str(), "r", stdin);
freopen((str + ".out").c_str(), "w", stdout);
}
const int N = 2e5 + 10, Log = 21;
int n, m, ui[N], vi[N], fa[N][Log + 1], ans;
int now, fst[N], nxt[N], num[N], dep[N], dis[N];
void add(int u, int v) {
nxt[++now] = fst[u], fst[u] = now, num[now] = v;
nxt[++now] = fst[v], fst[v] = now, num[now] = u;
}
void pre_lca(int xi) {
dep[xi] = dep[fa[xi][0]] + 1;
for(int i = 1; i <= Log; i++)
fa[xi][i] = fa[fa[xi][i - 1]][i - 1];
for(int i = fst[xi]; i; i = nxt[i])
if(num[i] != fa[xi][0])
fa[num[i]][0] = xi, pre_lca(num[i]);
}
int get_lca(int xi, int yi) {
if(dep[xi] < dep[yi]) swap(xi, yi);
for(int i = Log; i >= 0; i--)
if(dep[fa[xi][i]] >= dep[yi])
xi = fa[xi][i];;
if(xi == yi) return xi;
for(int i = Log; i >= 0; i--)
if(fa[xi][i] != fa[yi][i])
xi = fa[xi][i], yi = fa[yi][i];
return fa[xi][0];
}
void get_ans(int xi) {
for(int i = fst[xi]; i; i = nxt[i])
if(num[i] != fa[xi][0]) {
get_ans(num[i]);
dis[xi] += dis[num[i]];
}
if(xi != 1) {
if(dis[xi] == 0) ans += m;
else if(dis[xi] == 1) ans += 1;
}
}
int main() {
//file("");
int u, v;
read(n), read(m);
for(int i = 1; i < n; i++)
read(u), read(v), add(u, v);
pre_lca(1);
for(int i = 1; i <= m; i++) {
read(u), read(v);
int f = get_lca(u, v);
if(f == u) dis[u]--, dis[v]++;
else if(f == v) dis[v]--, dis[u]++;
else dis[f] -= 2, dis[u]++, dis[v]++;
}
get_ans(1);
cout << ans << endl;
return 0;
}