【算法】并查集
1. 并查集简介
1.1 什么是并查集
并查集是一种用于管理元素所属集合的数据结构,实现为一个森林,其中每棵树表示一个集合,树中的节点表示对应集合中的元素。
并查集支持两种操作:
1. 合并(merge):合并两个数所属的集合(合并两个树);
2. 查询(find):查询两个数是否在同一个集合中(查询两个数所对应的集合)
值得一提的是,并查集的适用范围很广。基本用于优化时间,当然也有其他的用途。请让我一一道来~。
1.2 并查集算法实现
先看一道题。
1.2.1 P3367 【模板】并查集
给定 \(n\) 个数,\(m\) 次操作。每一个数初始对应一个集合。
操作有两种:
-
合并 \(x\), \(y\) 所在集合;
-
查询 \(x\), \(y\) 是否在同一集合;
1.2.2 初始化
我们规定 \(f_i\) 表示在 \(i\) 所在的并查集中,\(i\) 的父节点的编号。对与初始 \(f_i\),每一个集合都有且仅有一个数,此时 \(i\) 的父节点即可记为 \(i\) 本身,于是就有了 \(f_i = i\)。这就是并查集的初始化。
Code:
for (int i = 1; i <= n; i++) f[i] = i;
1.2.3 查找祖先
查找祖先就是沿着树向上移动,直至找到根节点的过程,其实就是“找爸爸”(通俗易懂)。
Code:
int find(int x) {
if(x == f[x]) return x;
return find(f[x]);
}
这没什么好说的,就是普通的遍历树的过程。
在无任何优化的时候,时间复杂度 \(O(n)\)。
1.2.4 合并操作
设要合并所在集合的两数为 \(x\), \(y\),直觉肯定是直接将 \(x\) 的父节点设为 \(y\),使得树与树之间变得联通。但其实这是错误的。
正确的做法是将一棵树的根节点连到另一棵树的根节点。因为只有这样,才能保证合并后树的形态不会发生变化。
比如:
假设要合并所在集合的两数为 \(3\),\(7\)。不是直接连接 \(3\),\(7\)。而是分别找到其祖先 \(1\),\(6\),再连接 \(1\),\(6\)。
如下图:
这样,合并之后依然是一棵树。
Code:
vois merge(int x, int y) {
f[find(x)] = find(y);
}
时间复杂度 \(O(1)\)
1.2.5 查询操作
想要知道 \(x\) 和 \(y\) 是否在同一集合中。等价于询问 \(x\),\(y\) 是否有共同的祖先,及是否在一棵树内。
Code:
bool check(int x, int y) {
return find(x) == find(y);
}
在无任何优化的时候,时间复杂度为 \(O(n)\)。
1.2.6 代码实现
于是,P3367 【模板】并查集 就能迎刃而解了。
Code:
#include <bits/stdc++.h>
#define ll long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
using namespace std;
inline int read() {
rint x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 200010;
int n, m, f[N];
int find(int x) {
return (x == f[x] ? x : find(f[x]));
}
signed main() {
n = read(), m = read();
For(i,1,n) f[i] = i;
while(m--) {
int op = read(), x = read(), y = read();
if(op == 1) {
f[find(x)] = find(y);
} else {
cout << (find(x) == find(y) ? 'Y' : 'N') << '\n';
}
}
return 0;
}
2. 并查集优化
2.1 路径压缩
你发现,把 1.2.6 的代码交上去只有 \(70pts\),TLE了三个点,这是因为在无任何优化的并查集,其形态会退化成链。
在某些时候,我们只关心集合与集合之间的关系,并不关心并查集的形态(树的形态)。而我们又想求得更高效的查询效率。这时,我们只需要使用路径压缩。
路径压缩的具体步骤是将所有节点的父节点设为并查集的根节点。
如下图:
这就很好的诠释了路径压缩的原理。
代码实现也很好写,只需要在 \(find\) 函数上动点手脚就行了。
Code:
int find(int x) {
return (x == f[x] ? x : f[x] = find(f[x]));
}
将加了路径压缩的并查集交到 P3367,就可以 A 了。
时间复杂度 \(O(\log n)\)(理论值)
2.2 按秩合并
按秩合并的方法有两种,第一种是按树高合并,第二种是按树的大小合并。
这里我们介绍按树的大小合并的按秩合并方式。
原理很简单,就是每次合并时将“小树”并到“大树”上。
记 \(siz_i\) 表示以 \(i\) 为根的并查集(树)的大小。则合并时,先要比较需合并的并查集大小。再将小的并到大的上即可。
Code:
int find(int x) {
return (x == f[x] ? x : f[x] = find(f[x]));
}
void merge(int x, int y) {
x = find(x), y = find(y);
if(siz[x] < siz[y]) {
siz[y] += siz[x];
f[y] = x;
} else {
siz[x] += siz[y];
f[x] = y;
}
}
时间复杂度均摊 \(O(\log n)\)。
3. 并查集时间复杂度分析
并查集最难的就在这里,我也不会分析。
但是有一个结论需要记住,就是当路径压缩 + 按秩合并同时加上时,其单次查询/查找祖先的时间复杂度为 \(O(\alpha(n))\),其中\(\alpha(n)\) 为反 Ackermann 函数。在 \(n\) 同级的情况下,\(O(\alpha(n))\) 会比 \(O(\log n)\) 快得多。
具体的并查集复杂度证明放在这里,有兴趣可以去看一看。
4. 并查集例题
4.1 P2024 [NOI2001] 食物链
Problem
有三类动物 \(A, B, C\),这三类动物的食物链构成了一个环。\(A\) 吃 \(B\),\(B\) 吃 \(C\),\(C\) 吃 \(A\)。
有 \(k\) 句话,\(N\) 个动物,每句话有两种说法:
- \(X\) 和 \(Y\) 是同类;
- \(X\) 吃 \(Y\);
每句话有真有假,一句话的真假有 \(3\) 种判定方法:
- 当前的话与前面的某些真的话冲突,就是假话;
- 当前的话中 \(X\) 或 \(Y\) 比 \(N\) 大,就是假话;
- 当前的话表示 \(X\) 吃 \(X\),就是假话。
求这 \(k\) 句话中有多少假话
Solve
很模板的一道题,这一类题被称之为 种类并查集。
顾名思义,就是用并查集维护种类关系。
对于每两个动物,有且仅有 \(3\) 种关系:同类,猎物,天敌。每种关系互相联系。
对于每句话的两种说法,种类之间的关系就有 \(6\) 种情况:
如果 \(X\) 和 \(Y\) 是同类;则:
- \(X\) 的同类是 \(Y\) 的同类;
- \(X\) 的猎物是 \(Y\) 的猎物;
- \(X\) 的天敌是 \(Y\) 的天敌;
如果 \(X\) 吃 \(Y\);则:
- \(X\) 的同类是 \(Y\) 的天敌;
- \(X\) 的猎物是 \(Y\) 的同类;
- \(X\) 的天敌是 \(Y\) 的猎物;
设一个动物的编号为 \(p\),其同类的编号为 \(p\),猎物的编号为 \(p + n\),的编号为 \(p + 2 \times n\)。
维护一个种类并查集,每每遇到到一句话,就先看它是否与之前的话产生矛盾,再将其种类与种类合并。
对于第一句话,则有:
merge(x, y); //X 的同类是 Y 的同类;
merge(x + n, y + n); //X 的猎物是 Y 的猎物;
merge(x + 2 * n, y + 2 * n); //X 的天敌是 Y 的天敌;
对于第二句话,则有:
merge(x + n, y); //X 的同类是 Y 的天敌;
merge(x, y + 2 * n); //X 的猎物是 Y 的同类;
merge(x + 2 * n, y + n); //X 的天敌是 Y 的猎物;
怎样判断这句话是否与之前的话产生矛盾呢?其实很简单。
对于第一句话,则有:
如果 \(X\) 的猎物是 \(Y\) 的同类或者 \(X\) 的同类是 \(Y\) 的猎物,则这句话是假话;
对于第二句话,则有:
如果 \(X\) 的同类是 \(Y\) 的同类或者 \(X\) 的同类是 \(Y\) 的猎物,则这句话是假话;
Q:check的时候没有用到“天敌”这一种类关系,是不是就不需要“天敌”这一关系了呢?
A:其实不是,只是check的时候不需要“天敌”这一关系,不代表不需要。比如,\(X\) 是 \(Y\) 的天敌,\(Z\) 是 \(Y\) 的天敌,则能表示 \(X\) 是 \(Z\) 的同类。所以“同类”,“猎物”,“天敌”三个种类关系缺一不可。
时间复杂度 \(O(k \log n)\)。
Code
#include <bits/stdc++.h>
#define ll long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
using namespace std;
inline int read() {
rint x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 5e4 + 10;
int n, k, f[N * 4], ans;
int find(int x) {
return (x == f[x] ? x : f[x] = find(f[x]));
}
void merge(int x, int y) {
f[find(x)] = find(y);
}
signed main() {
n = read(), k = read();
For(i,1,n*4) f[i] = i;
For(i,1,k) {
int op = read(), x = read(), y = read();
if(x > n || y > n) { ans++; continue; }
if(op == 1) {
if(find(x + n) == find(y) || find(y + n) == find(x)) ans++;
else {
merge(x, y);
merge(x + n, y + n);
merge(x + 2 * n, y + 2 * n);
}
} else {
if(find(x) == find(y) || find(x) == find(y + n)) ans++;
else {
merge(x + n, y);
merge(x, y + 2 * n);
merge(x + 2 * n, y + n);
}
}
}
cout << ans << '\n';
return 0;
}
4.2 P1196 [NOI2002] 银河英雄传说
Problem
有一个被划分为 \(30000\) 列的战场,初始每一列有一个战舰。
给定 \(m\) 次指令,指令有两种:
- 将第 \(i\) 号战舰所在的整个战舰队列接到第 \(j\) 号战舰所在的战舰队列的尾部;
- 第 \(i\) 号战舰与第 \(j\) 号战舰当前是否在同一列中,如果在同一列中,询问它们之间布置有多少战舰;
Solve
也是一个很模板的题,这一类题被称之为 带权并查集。
记 \(num_i\) 表示 \(i\) 所在战舰队列的战舰数量,\(s_i\) 表示 \(i\) 在所在战舰队列到队首的距离。
合并操作:由于 \(j\) 接在 \(i\) 前面,所以 \(i\) 的位置会往后移动 \(num_j\) 个位置。并且将 \(num_j\) 加上 \(num_i\)。
查询操作:先判断第 \(u\) 号战舰与第 \(j\) 号战舰当前是否在同一列中,如果在,则 \(i\) 与 \(j\) 之间的战舰数为 \(|s_i - s_j| - 1\);
时间复杂度 \(O(T \log n)\)。
Code
#include <bits/stdc++.h>
#define ll long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
using namespace std;
inline int read() {
rint x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 3e4 + 10;
int T, f[N], s[N], num[N];
int find(int x) {
if(x == f[x]) return f[x];
else {
int fa = find(f[x]);
s[x] += s[f[x]];
return f[x] = fa;
}
}
void merge(int x, int y) {
x = find(x), y = find(y);
s[x] += num[y];
f[x] = y;
num[y] += num[x];
num[x] = 0;
}
signed main() {
For(i,1,N) f[i] = i, num[i] = 1, s[i] = 0;
T = read();
while(T--) {
char c;
cin >> c;
int x = read(), y = read();
if(c == 'M') {
merge(x, y);
} else {
if(find(x) != find(y)) cout << -1 << '\n';
else cout << abs(s[x] - s[y]) - 1 << '\n';
}
}
return 0;
}
4.3 P1197 [JSOI2008] 星球大战
Problem
给定 \(n\) 个点 \(m\) 条边的无向图。有 \(k\) 个点即将被删除,删除后其连边将消失。
顺次将 \(k\) 个节点删除,问每次删除后还剩多少个连通块。
Solve
当然可以按照题面模拟此过程。但不过最坏时间复杂度会达到 \(O(nm)\),显然过不了。
“顺次删掉 \(k\) 个节点”的时间复杂度很高,单次操作会达到 \(O(m)\)。有什么办法可以优化呢?
此时只需要倒过来做,把 “顺次删掉 \(k\) 个节点” 变为 “逆次增加 \(k\) 个节点”。
可以先预处理出当 \(k\) 个节点已经被删除之后的连通块数量。再将第 \(k-1,k-2,k-3,\dots,1\) 节点添加。每次添加可以使用并查集判断连通块是否连通。
具体细节详见代码。
时间复杂度 \(O(k \log n)\)。
Code
#include <bits/stdc++.h>
#define ll long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
using namespace std;
inline int read() {
rint x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 5e6 + 10;
struct Node {
int u, v, nx;
} e[N];
int n, m, k, h[N], tot, ans[N], num, bre[N], f[N];
bool vis[N];
void add(int u, int v) {
e[++tot].u = u, e[tot].v = v, e[tot].nx = h[u], h[u] = tot;
}
int find(int x) {
return (x == f[x] ? x : f[x] = find(f[x]));
}
signed main() {
n = read(), m = read();
For(i,1,n) f[i] = i;
For(i,1,m) {
int u = read(), v = read();
add(u, v), add(v, u);
}
k = read();
For(i,1,k) bre[i] = read(), vis[bre[i]] = 1;
num = n - k;
For(i,1,2*m) {
if(!vis[e[i].u] && !vis[e[i].v]) {
if(find(e[i].u) != find(e[i].v)) {
num--;
f[find(e[i].u)] = find(e[i].v);
}
}
}
ans[k + 1] = num;
FOR(i,k,1) {
vis[bre[i]] = 0;
num++;
for (int j = h[bre[i]]; j; j = e[j].nx) {
if(!vis[e[j].v] && find(bre[i]) != find(e[j].v)) {
num--;
f[find(e[j].v)] = find(bre[i]);
}
}
ans[i] = num;
}
For(i,1,k+1) cout << ans[i] << '\n';
return 0;
}
4.4 P2391 白雪皑皑
Problem
有 \(n\) 个雪花,\(m\) 次操作,第 \(i\) 次操作会把 \((i\times p + q) \mod n + 1\) 与 \((i \times q + p) \mod n + 1\) 之间的雪花染成颜色 \(i\),问 \(\forall i\) 的雪花的颜色。
Solve
这是一道很好的并查集维护序列连通性问题。
首先,一个雪花被染成什么颜色取决于它最后一次被染的颜色。那么,前面的所有操作对一片雪花的染色都“浪费了”。所以我们可以从后往前遍历所有询问,对于每一个被染过色的雪花就不对其进行染色了。这样,由于每个数只会被染一次,期望时间复杂度 \(O(n)\)。
不过,我们无法快速的知道一个区间内哪一个雪花已经染了颜色,哪些雪花没有。
这时,我们可以用并查集维护那些没有被染过色的位置。
我们规定 \(f_i\) 表示区间 \([i,n]\) 中下一个未被染色的位置。
当前点 \(i\) 已经被染过色了,直接将 \(i + 1\) 加入并查集 \(f_i\)(将 \(f_i\) 连向 \(i + 1\))。
Code
#include <bits/stdc++.h>
#define int long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007
using namespace std;
inline int read() {
rint x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9){print(x/10);putchar(x%10+'0');}
else putchar(x+'0');
return;
}
const int N = 1e6 + 10;
int n, m, p, q, a[N], f[N], Nop, col[N];
int find(int x) {
return (x == f[x] ? x : f[x] = find(f[x]));
}
signed main() {
n = read(), m = read(), p = read(), q = read();
For(i,1,n+5) f[i] = i;
Nop = n;
for (rint i = m; i >= 1 && Nop; i--) {
int l = (i * p + q) % n + 1;
int r = (i * q + p) % n + 1;
if(l > r) swap(l, r);
int x = find(l);
while(x <= r) {
col[x] = i, Nop--, f[x] = find(x + 1);
x = find(x + 1);
}
}
For(i,1,n) printf("%lld\n", col[i]);
return 0;
}