Union-Find in Practice (from Algorithms)
union-find
Costs of the different implementations, where \(N\) is the number of nodes.
Algorithm | connect() | find() |
---|---|---|
quick-find | \(N\) | \(1\) |
quick-union | tree height (best case \(1\), worst case \(N\)) | tree height |
weighted-quick-union | \(\lg N\) | \(\lg N\) |
path-compressed-weighted-quick-union | amortized, very close to \(1\): bounded by the inverse Ackermann function, which is \(<5\) in practice | same as connect() |
height-quick-union | \(\lg N\) | \(\lg N\) |
The Ackermann function grows extremely fast, far faster than the factorial \(n!\) (in fact faster than any primitive recursive function). It is defined as
\[\begin{align*} A(m, n) &= \begin{cases} n + 1 & m = 0, \\ A(m - 1, 1) & m > 0, n = 0, \\ A(m - 1, A(m, n - 1)) & m > 0, n > 0, \end{cases} \\ A(n) &= A(n, n). \end{align*} \]
For small \(m\) the definition unrolls into closed forms:
\[\begin{align*} A(1, n) &= n + 2, \\ A(2, n) &= 2 n + 3, \\ A(3, n) &= 2 ^ {n + 3} - 3, \\ A(4, n) &= \underbrace{2^{2^{\cdot^{\cdot^{\cdot^2}}}}}_{n + 3} - 3. \end{align*} \]
The diagonal values are therefore
\[\begin{align*} A(0) &= 1, \\ A(1) &= 3, \\ A(2) &= 7, \\ A(3) &= 61, \\ A(4) &= 2^{2^{2^{65536}}} - 3, \\ \dots \end{align*} \]
\(2^{65536}\) alone has \(19729\) decimal digits, so \(A(4)\) is already too large to write out even if the entire universe were used to represent it, and the values after it are beyond imagination.
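To get a feel for the recursion, the definition can be transcribed directly and evaluated for the few arguments that are still tractable. This is only a sketch: the class name `AckermannDemo` and the method name `ackermann` are illustrative, and the naive recursion is usable only for very small \(m\) and \(n\).

```java
class AckermannDemo {
    // Direct transcription of the recursive definition.
    // Only safe for tiny arguments: values and recursion depth explode quickly.
    static long ackermann(long m, long n) {
        if (m == 0) return n + 1;
        if (n == 0) return ackermann(m - 1, 1);
        return ackermann(m - 1, ackermann(m, n - 1));
    }

    public static void main(String[] args) {
        // Diagonal values A(k) = A(k, k) for k = 0..3
        for (int k = 0; k <= 3; k++) {
            System.out.println("A(" + k + ", " + k + ") = " + ackermann(k, k));
        }
        // prints 1, 3, 7, 61
    }
}
```

Going one step further to \(A(4, 4)\) is hopeless with this code, which is exactly the point of the function.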
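Interface
Two functions matter most:
- connect(): joins two nodes;
- find(): returns the id of the connected component that contains a node.
The remaining functions are simple:
- connected(): just evaluates find(v) == find(w);
- count(): tracked in a counter variable.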
public interface UnionFind {
    // Dynamic connectivity of a graph
    /**
     * Connects two nodes.
     *
     * @param v the first node
     * @param w the second node
     */
    public void connect(int v, int w);
    /**
     * @param v the first node
     * @param w the second node
     * @return whether the two nodes are connected
     */
    public boolean connected(int v, int w);
    /**
     * Returns the id of the connected component that contains the node.
     *
     * @param v the node
     * @return the id of its connected component
     */
    public int find(int v);
    /**
     * @return the number of connected components
     */
    public int count();
    /**
     * Selects a union-find implementation.
     *
     * @param type name of the implementation
     * @param n    number of nodes
     * @return a union-find instance, or null if the name is unknown
     */
    public static UnionFind New(String type, int n) {
        switch (type) {
            case "quick-find":
                return new QuickFind(n);
            case "quick-union":
                return new QuickUnion(n);
            case "weighted-quick-union":
                return new WeightedQuickUnion(n);
            case "path-compressed-weighted-quick-union":
                return new PathCompressedWeightedQuickUnion(n);
            case "height-quick-union":
                return new HeightQuickUnion(n);
            default:
                return null;
        }
    }
}
quick-find
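- An array stores the connected-component id of every node;
- connect(): each call scans the whole array and rewrites the id of every node that belongs to one of the two components;
- find(): simply returns the entry from the component array, so it costs almost nothing.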
class QuickFind implements UnionFind {
    private int[] id;
    private int count;
    public QuickFind(int n) {
        id = new int[n];
        for (int i = 0; i != n; ++i)
            id[i] = i;
        count = n;
    }
    public void connect(int v, int w) {
        int vId = id[v], wId = id[w];
        if (vId == wId)
            return;
        for (int i = id.length - 1; i >= 0; --i)
            if (id[i] == vId)
                id[i] = wId;
        --count;
    }
    public boolean connected(int v, int w) {
        return find(v) == find(w);
    }
    public int find(int v) {
        return id[v];
    }
    public int count() {
        return count;
    }
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0, n = id.length; i < n; i++)
            sb.append(find(i) + " ");
        return sb.toString();
    }
}
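A short trace may make the cost of connect() more concrete. The sketch below repeats the quick-find rule on a standalone array so that it runs on its own; the class name `QuickFindTrace` and the helper `run` are illustrative and not part of the implementation above.

```java
import java.util.Arrays;

class QuickFindTrace {
    // Applies the quick-find connect() rule to each pair on a fresh array
    // of n nodes and returns the final component-id array.
    static int[] run(int n, int[][] pairs) {
        int[] id = new int[n];
        for (int i = 0; i < n; i++) id[i] = i;
        for (int[] p : pairs) {
            int vId = id[p[0]], wId = id[p[1]];
            if (vId == wId) continue;
            // Full scan: every node in v's component is relabelled.
            for (int i = 0; i < n; i++)
                if (id[i] == vId) id[i] = wId;
        }
        return id;
    }

    public static void main(String[] args) {
        int[][] pairs = {{0, 1}, {2, 3}, {1, 3}};
        // [0,1,2,3,4] -> [1,1,2,3,4] -> [1,1,3,3,4] -> [3,3,3,3,4]
        System.out.println(Arrays.toString(run(5, pairs))); // [3, 3, 3, 3, 4]
    }
}
```

Every successful connect() touches all \(N\) entries, which is where the \(N\) in the cost table comes from.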
quick-union
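- Connected components are recorded as trees stored in a single array: each tree represents one connected component, and the index of its root is the id of that component;
- connect(): just links the roots of the two trees;
- find(): walks from the node up to the root and returns the root's index;
- It is not guaranteed to be faster than quick-find.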
class QuickUnion implements UnionFind {
    protected int[] id;
    protected int count;
    public QuickUnion(int n) {
        id = new int[n];
        for (int i = 0; i != n; ++i)
            id[i] = i;
        count = n;
    }
    public void connect(int v, int w) {
        int vId = find(v), wId = find(w);
        if (vId != wId) {
            id[vId] = wId;
            --count;
        }
    }
    public boolean connected(int v, int w) {
        return find(v) == find(w);
    }
    public int find(int v) {
        int vId;
        while ((vId = id[v]) != v) {
            v = vId;
        }
        return vId;
    }
    public int count() {
        return count;
    }
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0, n = id.length; i < n; i++)
            sb.append(find(i) + " ");
        return sb.toString();
    }
}
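The worst case for quick-union is easy to provoke. The following sketch uses the same linking rule (root of v is attached to root of w) and connects node 0 to every other node in turn, which degenerates the tree into a chain. The class `QuickUnionWorstCase` and the helpers `depth` and `chainDepth` are illustrative names.

```java
class QuickUnionWorstCase {
    static int[] id;

    static int find(int v) {
        while (id[v] != v) v = id[v];
        return v;
    }

    // Same rule as quick-union: root of v is linked under root of w.
    static void connect(int v, int w) {
        int vId = find(v), wId = find(w);
        if (vId != wId) id[vId] = wId;
    }

    // Number of links followed from v up to its root.
    static int depth(int v) {
        int d = 0;
        while (id[v] != v) { v = id[v]; d++; }
        return d;
    }

    // Connects 0-1, 0-2, ..., 0-(n-1) and returns the final depth of node 0.
    static int chainDepth(int n) {
        id = new int[n];
        for (int i = 0; i < n; i++) id[i] = i;
        for (int w = 1; w < n; w++) connect(0, w);
        return depth(0);
    }

    public static void main(String[] args) {
        System.out.println(chainDepth(1000)); // 999
    }
}
```

With this input order each find(0) walks the whole chain, so the total work is quadratic, no better than quick-find.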
weighted-quick-union
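- A variant of quick-union that overrides connect();
- quick-union may attach a large tree to a small one; an extra int[] sz records the size of each tree, which guarantees that the smaller tree is always attached to the larger one;
- It can be proved that the depth of any node is then at most \(\lg N\).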
class WeightedQuickUnion extends QuickUnion {
    protected int[] sz;
    public WeightedQuickUnion(int n) {
        super(n);
        sz = new int[n];
        for (int i = 0; i != n; ++i) {
            sz[i] = 1;
        }
    }
    public void connect(int v, int w) {
        int vId = find(v), wId = find(w);
        if (vId != wId) {
            if (sz[vId] < sz[wId]) {
                id[vId] = wId;
                sz[wId] += sz[vId];
            } else {
                id[wId] = vId;
                sz[vId] += sz[wId];
            }
            --count;
        }
    }
}
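The \(\lg N\) depth bound can also be checked empirically. The sketch below is a standalone re-implementation of union by size (not the class above) that joins random pairs until one component remains and then measures the deepest node. The class name `WeightedDepthCheck`, the helper names, and the seed value are all arbitrary choices made for illustration.

```java
import java.util.Random;

class WeightedDepthCheck {
    static int root(int[] id, int v) {
        while (id[v] != v) v = id[v];
        return v;
    }

    // Number of links from v up to its root.
    static int depth(int[] id, int v) {
        int d = 0;
        while (id[v] != v) { v = id[v]; d++; }
        return d;
    }

    // Joins random pairs with union by size until a single component
    // remains, then returns the maximum depth over all nodes.
    static int maxDepthAfterRandomUnions(int n, long seed) {
        int[] id = new int[n], sz = new int[n];
        for (int i = 0; i < n; i++) { id[i] = i; sz[i] = 1; }
        Random r = new Random(seed);
        int count = n;
        while (count > 1) {
            int v = root(id, r.nextInt(n)), w = root(id, r.nextInt(n));
            if (v == w) continue;
            if (sz[v] < sz[w]) { id[v] = w; sz[w] += sz[v]; }
            else               { id[w] = v; sz[v] += sz[w]; }
            count--;
        }
        int max = 0;
        for (int i = 0; i < n; i++) max = Math.max(max, depth(id, i));
        return max;
    }

    public static void main(String[] args) {
        int n = 1 << 12; // lg n = 12
        System.out.println("max depth = " + maxDepthAfterRandomUnions(n, 42L) + ", lg n = 12");
    }
}
```

Whatever the seed, the reported depth should never exceed \(\lg N\); in random runs it is typically well below it.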
path-compressed-weighted-quick-union
- A variant of weighted-quick-union that overrides find();
- find() gains an extra step: every node visited on the way up is re-attached directly to the root.
class PathCompressedWeightedQuickUnion extends WeightedQuickUnion {
    public PathCompressedWeightedQuickUnion(int n) {
        super(n);
    }
    public int find(int v) {
        int w = v;
        int vId;
        // first pass: walk up until v is the root
        while ((vId = id[v]) != v) {
            v = vId;
        }
        // second pass: re-attach every node on the path directly to the root
        while (w != v) {
            int t = w;
            w = id[w];
            id[t] = v;
        }
        return v;
    }
}
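The effect of the second pass is easiest to see on a hand-built chain. This is a minimal standalone sketch of the same two-pass find on a plain parent array; `PathCompressionDemo` and the local variable names are illustrative.

```java
import java.util.Arrays;

class PathCompressionDemo {
    // Two-pass find: locate the root, then repoint every node on the
    // path from v so that it links directly to that root.
    static int find(int[] id, int v) {
        int root = v;
        while (id[root] != root) root = id[root];
        while (v != root) {
            int next = id[v];
            id[v] = root;
            v = next;
        }
        return root;
    }

    public static void main(String[] args) {
        int[] id = {1, 2, 3, 4, 4};              // chain 0 -> 1 -> 2 -> 3 -> 4
        System.out.println(find(id, 0));         // 4
        System.out.println(Arrays.toString(id)); // [4, 4, 4, 4, 4]
    }
}
```

After one find(0) the chain has become a flat tree, so any later find on these nodes follows a single link.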
height-quick-union
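- A variant of weighted-quick-union that overrides connect();
- The height of each tree, rather than its number of nodes, decides which root is attached to which;
- The height grows only when two trees of equal height are merged.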
class HeightQuickUnion extends WeightedQuickUnion {
    // The inherited sz[] is reused to store tree heights (every entry starts at 1).
    public HeightQuickUnion(int n) {
        super(n);
    }
    public void connect(int v, int w) {
        int vId = find(v), wId = find(w);
        if (vId != wId) {
            if (sz[vId] < sz[wId]) {
                id[vId] = wId;
            } else if (sz[vId] > sz[wId]) {
                id[wId] = vId;
            } else {
                // equal heights: the merged tree is one level taller
                id[wId] = vId;
                ++sz[vId];
            }
            --count;
        }
    }
}
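A tiny example shows when the height actually changes. The sketch below is a standalone version with a separate `height` array that starts at 0 for single nodes (an assumption made for readability; the class above starts its reused array at 1). `HeightUnionDemo` is an illustrative name.

```java
class HeightUnionDemo {
    int[] id, height;

    HeightUnionDemo(int n) {
        id = new int[n];
        height = new int[n]; // a single node has height 0
        for (int i = 0; i < n; i++) id[i] = i;
    }

    int find(int v) {
        while (id[v] != v) v = id[v];
        return v;
    }

    // Attach the lower tree under the taller one; only a tie adds a level.
    void connect(int v, int w) {
        int a = find(v), b = find(w);
        if (a == b) return;
        if (height[a] < height[b]) id[a] = b;
        else if (height[a] > height[b]) id[b] = a;
        else { id[b] = a; height[a]++; }
    }

    public static void main(String[] args) {
        HeightUnionDemo uf = new HeightUnionDemo(4);
        uf.connect(0, 1);                          // tie: height of root 0 becomes 1
        System.out.println(uf.height[uf.find(0)]); // 1
        uf.connect(2, 0);                          // lower tree goes under root 0
        uf.connect(3, 0);                          // same again
        System.out.println(uf.height[uf.find(3)]); // still 1
    }
}
```

Four nodes end up in a tree of height 1, because only the first merge was between trees of equal height.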
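Proof
By induction on tree size, the depth of any node in weighted-quick-union is at most \(\lg N\) (the bound also holds for height-quick-union, where the analogous invariant is that a tree of height \(h\) contains at least \(2^h\) nodes):
- Base case: a single-node tree has size \(1\) and depth \(0 = \lg 1\). Inductive hypothesis: in every tree of size \(k\) built so far, each node has depth at most \(\lg k\);
- Take two trees \(A\) and \(B\), of sizes \(i\) and \(j\) respectively, with \(i \le j\);
- By connect(), the smaller tree is attached to the root of the larger one. Depths in the larger tree do not change, so they remain at most \(\lg j \le \lg {(i + j)}\). Every node of the smaller tree moves one level deeper, to at most \(\lg i + 1\), and \(1 + \lg i = \lg {2i} \le \lg {(i + j)}\) because \(i \le j\). QED.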
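Testing
The test case comes from the Algorithms book. It reads two numbers: the number of nodes (n) and the number of repetitions (trials). Each trial connects random pairs of nodes until all nodes form a single connected component. The program then prints the elapsed time, the number of edges, and \(\frac{n \ln n}{2}\), to check that the number of edges is approximately \(\frac{n \ln n}{2}\).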
import java.util.Random;

// The methods below belong inside the test driver class.
public static int count(String type, int n) {
    UnionFind uf = UnionFind.New(type, n);
    if (uf == null) {
        System.out.println("null");
        return 0;
    }
    Random r = new Random();
    int edges = 0;
    // add random edges until a single component remains
    while (uf.count() > 1) {
        int i = r.nextInt(n), j = r.nextInt(n);
        uf.connect(i, j);
        edges++;
    }
    return edges;
}
public static double mean(int[] edges) {
    double sum = 0;
    for (int edge : edges) sum += edge;
    return sum / (double) edges.length;
}
public static double stddev(int[] edges) {
    double mean = mean(edges);
    double sum = 0;
    for (int edge : edges) sum += Math.pow(edge - mean, 2);
    // sample standard deviation: divide by (number of samples - 1)
    return Math.sqrt(sum / (edges.length - 1));
}
public static void main(String[] args) throws Exception {
    int n = Integer.parseInt(args[0]); // number of vertices
    int trials = Integer.parseInt(args[1]); // number of trials
    String[] types = {
        "quick-find",
        "quick-union",
        "weighted-quick-union",
        "path-compressed-weighted-quick-union",
        "height-quick-union"
    };
    for (String type : types) {
        long start = System.currentTimeMillis();
        int[] edges = new int[trials]; // record statistics
        // repeat the experiment trials times
        for (int t = 0; t < trials; t++) {
            edges[t] = count(type, n);
        }
        long finish = System.currentTimeMillis();
        long timeElapsed = finish - start;
        // report statistics (cost time is milliseconds per trial)
        System.out.println(type);
        System.out.println("1/2 n ln n = " + 0.5 * n * Math.log(n));
        System.out.println("mean = " + mean(edges));
        System.out.println("stddev = " + stddev(edges));
        System.out.println("cost time = " + (double) timeElapsed / trials);
        System.out.println("-----------------");
    }
}
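As a sanity check on the statistics helpers, it can help to run them on an array small enough to verify by hand. The sketch below assumes the usual sample standard deviation (sum of squared deviations divided by the number of samples minus one); the class name `SampleStats` and the input data are made up for illustration.

```java
class SampleStats {
    static double mean(int[] xs) {
        double sum = 0;
        for (int x : xs) sum += x;
        return sum / xs.length;
    }

    // Sample standard deviation: sqrt( sum (x - mean)^2 / (n - 1) ).
    static double stddev(int[] xs) {
        double m = mean(xs);
        double sum = 0;
        for (int x : xs) sum += (x - m) * (x - m);
        return Math.sqrt(sum / (xs.length - 1));
    }

    public static void main(String[] args) {
        int[] xs = {2, 4, 4, 4, 5, 5, 7, 9};
        // mean = 40 / 8 = 5.0; squared deviations sum to 32
        System.out.println("mean   = " + mean(xs)); // mean   = 5.0
        System.out.println("stddev = " + stddev(xs)); // sqrt(32 / 7)
    }
}
```

For this input the squared deviations are 9, 1, 1, 1, 0, 0, 4, 16, so the result is \(\sqrt{32/7}\), roughly 2.14.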
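Results for input 4000 60. The stddev figures in the recorded output were obtained by dividing the sum of squared deviations by the mean rather than by the number of trials, so they are not true standard deviations: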
quick-find
1/2 n ln n = 16588.099280204056
mean = 17842.566666666666
stddev = 159.69145825839223
cost time = 7.583333333333333
-----------------
quick-union
1/2 n ln n = 16588.099280204056
mean = 17962.583333333332
stddev = 182.13129529930035
cost time = 18.05
-----------------
weighted-quick-union
1/2 n ln n = 16588.099280204056
mean = 17572.216666666667
stddev = 127.10375102496774
cost time = 0.7666666666666667
-----------------
path-compressed-weighted-quick-union
1/2 n ln n = 16588.099280204056
mean = 17908.05
stddev = 147.37036396873788
cost time = 0.6833333333333333
-----------------
height-quick-union
1/2 n ln n = 16588.099280204056
mean = 17834.516666666666
stddev = 154.96088425590762
cost time = 0.7166666666666667
-----------------
Results for input 40000 60:
quick-find
1/2 n ln n = 211932.69466192144
mean = 224485.96666666667
stddev = 385.98095877302234
cost time = 916.5
-----------------
quick-union
1/2 n ln n = 211932.69466192144
mean = 225757.96666666667
stddev = 442.2855971228588
cost time = 5778.55
-----------------
weighted-quick-union
1/2 n ln n = 211932.69466192144
mean = 221421.55
stddev = 404.7538433217767
cost time = 9.35
-----------------
path-compressed-weighted-quick-union
1/2 n ln n = 211932.69466192144
mean = 222613.03333333333
stddev = 439.7313101308586
cost time = 7.3
-----------------
height-quick-union
1/2 n ln n = 211932.69466192144
mean = 220146.81666666668
stddev = 468.8915493122704
cost time = 9.366666666666667
-----------------