砍树
砍树
给定一棵由 $n$ 个结点组成的树以及 $m$ 个不重复的无序数对 $(a_1,b_1),(a_2,b_2), \ldots ,(a_m,b_m)$,其中 $a_i$ 互不相同,$b_i$ 互不相同,$a_i \ne b_j \ (1 \leq i,j \leq m)$。
小明想知道是否能够选择一条树上的边砍断,使得对于每个 $(a_i,b_i)$ 满足 $a_i$ 和 $b_i$ 不连通,如果可以则输出应该断掉的边的编号(编号按输入顺序从 $1$ 开始),否则输出 $−1$。
输入格式
输入共 $n+m$ 行,第一行为两个正整数 $n,m$。
后面 $n−1$ 行,每行两个正整数 $x_i,y_i$ 表示第 $i$ 条边的两个端点。
后面 $m$ 行,每行两个正整数 $a_i,b_i$。
输出格式
一行一个整数,表示答案,如有多个答案,输出编号最大的一个。
数据范围
对于 $30\%$ 的数据,保证 $1 < n \leq 1000$。
对于 $100\%$ 的数据,保证 $1 < n \leq {10}^5$,$1 \leq m \leq \frac{n}{2}$。
输入样例:
6 2 1 2 2 3 4 3 2 5 6 5 3 6 4 5
输出样例:
4
样例解释
断开第 $2$ 条边后形成两个连通块:$\{3,4\}$,$\{1,2,5,6\}$,满足 $3$ 和 $6$ 不连通,$4$ 和 $5$ 不连通。
断开第 $4$ 条边后形成两个连通块:$\{1,2,3,4\}$,$\{5,6\}$,同样满足 $3$ 和 $6$ 不连通,$4$ 和 $5$ 不连通。
$4$ 编号更大,因此答案为 $4$。
解题思路
由于在树中任意两点间的路径是固定的,因此要使得数对$(a_i,b_i)$对应的两个点不连通,那么只需要把$a_i \to b_i$路径上的任意一条边砍掉即可。因此可以枚举所有的数对$(a_i,b_i)$,并对$a_i \to b_i$路径上的所有边都累加$1$,表示砍掉这条边可以使得一个数对不连通。最后只需要遍历树中所有边找出被累加了$m$次的边并找到最大的编号。
由于需要对树中某条路径上的边都加上某个数,因此需要用到树上差分中的边差分,关于树上差分的简单介绍可以参考链接。同时还需要用哈希表存边$(a_i,b_i)$与编号的映射,为了避免使用 std::map 这里把二元组$(a_i,b_i)$映射为$a_i \times 100001 + b_i$,由于$b_i \leq 10^5$因此$P$取到$100001$,相当于把$P$进制数转换为十进制数。然后再用 std::unordered_map 来存边与编号的映射关系。当然如果是cf的题还是老老实实用 std::map 吧(悲。
AC代码如下,时间复杂度为$O(n + m \log{n})$:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 typedef long long LL; 5 6 const int N = 1e5 + 10, M = N * 2; 7 8 int n, m; 9 int head[N], e[M], ne[M], idx; 10 int fa[N][17], dep[N]; 11 int q[N], hh, tt = -1; 12 unordered_map<LL, int> mp; 13 int d[N]; 14 int ans = -1; 15 16 void add(int v, int w) { 17 e[idx] = w, ne[idx] = head[v], head[v] = idx++; 18 } 19 20 LL get(int x, int y) { 21 return x * 100001ll + y; 22 } 23 24 int lca(int a, int b) { 25 if (dep[a] < dep[b]) swap(a, b); 26 for (int i = 16; i >= 0; i--) { 27 if (dep[fa[a][i]] >= dep[b]) a = fa[a][i]; 28 } 29 if (a == b) return a; 30 for (int i = 16; i >= 0; i--) { 31 if (fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i]; 32 } 33 return fa[a][0]; 34 } 35 36 void dfs(int u, int pre) { 37 for (int i = head[u]; i != -1; i = ne[i]) { 38 if (e[i] != pre) { 39 dfs(e[i], u); 40 d[u] += d[e[i]]; 41 } 42 } 43 if (pre != -1 && d[u] == m) ans = max(ans, mp[get(u, pre)]); 44 } 45 46 int main() { 47 scanf("%d %d", &n, &m); 48 memset(head, -1, sizeof(head)); 49 for (int i = 1; i < n; i++) { 50 int v, w; 51 scanf("%d %d", &v, &w); 52 add(v, w), add(w, v); 53 mp[get(v, w)] = mp[get(w, v)] = i; 54 } 55 memset(dep, 0x3f, sizeof(dep)); 56 dep[0] = 0, dep[1] = 1; 57 q[++tt] = 1; 58 while (hh <= tt) { 59 int t = q[hh++]; 60 for (int i = head[t]; i != -1; i = ne[i]) { 61 if (dep[e[i]] > dep[t] + 1) { 62 dep[e[i]] = dep[t] + 1; 63 q[++tt] = e[i]; 64 fa[e[i]][0] = t; 65 for (int j = 1; j <= 16; j++) { 66 fa[e[i]][j] = fa[fa[e[i]][j - 1]][j - 1]; 67 } 68 } 69 } 70 } 71 for (int i = 0; i < m; i++) { 72 int a, b; 73 scanf("%d %d", &a, &b); 74 d[a]++, d[b]++, d[lca(a, b)] -= 2; 75 } 76 dfs(1, -1); 77 printf("%d", ans); 78 79 return 0; 80 }
2024-04-05 更新代码。
AC 代码如下,时间复杂度为 $O((n + m) \log{n})$:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 5, M = N * 2;
int n, m;
int h[N], e[M], id[M], ne[M], idx;
int fa[N][17], dep[N];
int s[N];
int ans;
void add(int u, int v, int w) {
e[idx] = v, id[idx] = w, ne[idx] = h[u], h[u] = idx++;
}
void dfs1(int u, int p) {
for (int i = h[u]; i != -1; i = ne[i]) {
int v = e[i];
if (v == p) continue;
dep[v] = dep[u] + 1;
fa[v][0] = u;
for (int i = 1; i <= 16; i++) {
fa[v][i] = fa[fa[v][i - 1]][i - 1];
}
dfs1(v, u);
}
}
void dfs2(int u, int p) {
for (int i = h[u]; i != -1; i = ne[i]) {
int v = e[i];
if (v == p) continue;
dfs2(v, u);
s[u] += s[v];
if (s[v] == m) ans = max(ans, id[i]);
}
}
int lca(int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
for (int i = 16; i >= 0; i--) {
if (dep[fa[a][i]] >= dep[b]) a = fa[a][i];
}
if (a == b) return a;
for (int i = 16; i >= 0; i--) {
if (fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
}
return fa[a][0];
}
int main() {
scanf("%d %d", &n, &m);
memset(h, -1, sizeof(h));
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
add(u, v, i), add(v, u, i);
}
dep[1] = 1;
dfs1(1, -1);
for (int i = 0; i < m; i++) {
int u, v;
scanf("%d %d", &u, &v);
int p = lca(u, v);
s[u]++, s[v]++, s[p] -= 2;
}
dfs2(1, -1);
if (!ans) ans = -1;
printf("%d", ans);
return 0;
}
本文来自博客园,作者:onlyblues,转载请注明原文链接:https://www.cnblogs.com/onlyblues/p/17363557.html