并查集模板hdu-1213
例题:hdu 1213
//#include <bits/stdc++.h>
#include <iostream>
#include <stack>
#include <string>
#include <queue>
#include <stack>
#include <set>
#include <list>
#include <map>
#include <algorithm>
#include <string.h>
using namespace std;
const int MAXN = 1005;
int s[MAXN];
void init_set(){
for(int i = 1; i <= MAXN; i++)
s[i] = i;
}
int find_set(int x){
return x = s[x]? x: find_set(s[x]);
}
void union_set(int x, int y){
x = find_set(x);
y = find_set(y);
if(x != y) s[x] = s[y];
}
int main(){
int t, n, m, x, y;
cin >> t;
while(t--){
cin >> n >> m;
init_set();
for(int i = 1; i <= m; i++){
cin >> x >> y;
union_set(x, y);
}
int ans = 0;
for(int i = 1; i <= n; i++){
if(s[i] == i) ans++;
}
cout << ans << endl;
}
return 0;
}
在上述程序中,find_set()、union_set()的搜索深度都是O(n),性能较差,下面进行优化,优化之后查找和合并的复杂度都是O(logn)。
1、合并的优化
在合并x 和 y时先搜到他们的根节点,然后再合并这两个根节点,即把一个根节点的集改成另一个根节点。这连个根节点的高度不同,如果把高度较小的那一个合并到较大的集上,能减小树的高度。下面是优化后的代码,再初始化时有height[i],定义元素i的高度。
int height[MAXN];
void init_set(){
for(int i = 1; i <= MAXN; i++){
s[i] = i;
height[i] = 0;
}
}
void union_set(int x, int y){
x = find_set(x);
y = find_set(y);
if(height[x] == height[y]){
height[x] = height[x] + 1;
s[y] = x;
}
else{
if(height[x] < height[y]) s[x] = y;
else s[y] = x;
}
}
2、查询的优化——路径的压缩
在上面的查询程序 find_set()中,查询元素 i 所属的集合需要搜索路径找到根节点,返回的结果也是根节点。这条路径可能很长。如果在返回的时候顺便把 i 所属的集改成根节点,那么下次搜索的时候在 O(1) 的时间内就能得到结果。
程序如下:
int find_set(int x){
if(x != s[x])
s[x] = find_set(s[x]);
return s[x];
}
上面的代码用递归实现,如果数据规模太大,担心爆栈,可以用下面的非递归代码。
int find_set(int x){
int r = x;
while(s[r] != r) //找到根节点
r = s[r];
int i = x, j;
while(i != r){
j = s[i];
s[i] = r;
i = j;
}
return r;
}
3、优化完成的代码
//#include <bits/stdc++.h>
#include <iostream>
#include <stack>
#include <string>
#include <queue>
#include <stack>
#include <set>
#include <list>
#include <map>
#include <algorithm>
#include <string.h>
using namespace std;
const int MAXN = 1005;
int s[MAXN];
int height[MAXN];
void init_set(){
for(int i = 1; i <= MAXN; i++){
s[i] = i;
height[i] = 0;
}
}
int find_set(int x){
int r = x;
while(s[r] != r) //找到根节点
r = s[r];
int i = x, j;
while(i != r){
j = s[i];
s[i] = r;
i = j;
}
return r;
}
void union_set(int x, int y){
x = find_set(x);
y = find_set(y);
if(height[x] == height[y]){
height[x] = height[x] + 1;
s[y] = x;
}
else{
if(height[x] < height[y]) s[x] = y;
else s[y] = x;
}
}
int main(){
int t, n, m, x, y;
cin >> t;
while(t--){
cin >> n >> m;
init_set();
for(int i = 1; i <= m; i++){
cin >> x >> y;
union_set(x, y);
}
int ans = 0;
for(int i = 1; i <= n; i++){
if(s[i] == i) ans++;
}
cout << ans << endl;
}
return 0;
}