算法导论 Chapter 9.3 Selection in worst-case linear time
问题描述:
本节要求以最坏情况下O(n)的时间复杂度找到长度为n的数组中第 i 大的数。
解决方案:
《算法导论》上提供了一个算法,该算法实质上是利用了快排中划分的思想,但其通过一些比较复杂的预处理工作保证了快排划分的均匀,
并且能够从理论上证明其最坏情况下的时间复杂度可以达到O(n)。
算法步骤:
1、如图所示,将n个数分成5个一组,共有⌊n/5⌋组。
2、对⌈n/5⌉组(包括可能不到5个数的那组)的组内数据进行直接插入排序,排序完成之后,图中白色的数据即为组内数据的中位数。
并将这⌈n/5⌉个中位数挑出来,具体做法见后面的代码。
3、递归调用本算法找到这⌈n/5⌉个中位数的中位数,假设它为图中的x。
4、假定n个数是互不相同的(后面会讨论一般情况),以x为枢轴对n个数进行一趟划分,
使得x左边的数<x,x右边的数>x。假设划分完成之后,x在数组中从左到右排在第k个。
5、如果 i == k,那么x就是我们要找的数,return x; 即可;
如果 i < k,那么递归调用本算法在x左边的数中继续找第 i 个数;
如果 i > k,那么递归调用本算法在x右边的数中继续找第 i - k 个数;
为什么这个算法是O(n)的?《算法导论》上给予了证明。
证明步骤:
①第1、2、4步都是O(n)的。
其中第2步O(n)是因为,对n/5组进行组内直接插入排序的时间复杂度是(n/5) * (5^2) = 5n,也就是O(n)。
②根据第1、2、3步我们可以知道,由于n个数是互不相同的,且x是中值的中值,
则图中阴影部分的数据肯定都比x大,具体来说,比x大的数据至少有:
其中,⌈n/5⌉是组的个数(包括可能不到5个数的那组),-2是去掉x所在的组以及有可能不到5个数的那组。
同理,在x左上角的那部分数据也肯定比x小,并且也至少有这么多。
所以无论在第5步中是哪种情况,到下一次递归时最多有7n/10+6个数。
③假设本算法的最坏时间复杂度是T(n)的,那么第3步的时间复杂度是T(⌈n/5⌉),第5步的时间复杂度是T(7n/10+6)。
假设对于n<140的情况,找第 i 个数是O(1)的。(后面会解释为什么有这样的假设,以及为什么是140而不是410)
④现在我们只要证明,对于任意的n>0,都成找到一个常数 c 使得T(n) ≤ cn,那么,这个算法就是O(n)的。
假设式中O(n)项的常数因子为a,则有:
T(n) ≤ c ⌈n/5⌉ + c(7n/10 + 6) + an
≤ cn/5 + c + 7cn/10 + 6c + an
= 9cn/10 + 7c + an
= cn + (-cn/10 + 7c + an)
如果(-cn/10 + 7c + an) ≤ 0,那么T(n) ≤ cn。
(-cn/10 + 7c + an) ≤ 0 ⇒ c ≥ 10a(n/(n - 70))
当n≥140时,n/(n - 70) ≤ 2,只要取c = 20a即可使T(n) ≤ cn,亦即本算法是O(n)的。
下面讨论如果n个数中有重复数的情况
由于在算法步骤中的第四步以及证明步骤中的第二步都假定个n个数互不相同,这样才能保证能有3n/10-6个数一定比x小,
同时有3n/10-6个数一定比x大,这样无论i和k(x在数组中从左到右排在第k个)之间的大小关系是什么,总能下进入下一次
递归时排除掉3n/10-6个数。
而事实上,如果n个数中有重复元素与x相同,比如100个数(中间省略的都是2):
2 2 1 2 2 0 2 …… 2 2 1 2 3
这样≤2的数在前99个,大于2的数(只有3)在第100个,如果我们要找第5个数,那么按照前面的算法进入下次递归的是前99个数,
这显然不能满足“至少除掉3n/10-6个数”的假设。为了实现在这种情况下仍然是O(n),需要DIY一下划分算法,具体来说:
将n个数划分成三部分,第一部分<x,第二部分=x,第三部分>x。按上面的例子划分结果是(中间省略的都是2):
1 0 1 2 2 2 2 …… 2 2 2 2 3
我们知道从左至右第一个出现的x在第4个,最后出现的x在第99个,那么:
如果 i >= 4 && i <= 99,那么x就是我们要找的数,return x; 即可;
如果 i < 4,那么递归调用本算法在x左边的数中继续找第 i 个数;
如果 i > 99,那么递归调用本算法在x右边的数中继续找第 i - 99 个数;
这样,就能够保证进入每次“至少除掉3n/10-6个数”了。
实现代码:
1 int partitionSpecifyPivot(int a[], int beg, int end, int pivotloc, int *pivotNum) 2 { 3 int pivot = a[pivotloc]; 4 int i = beg - 1; 5 int j = beg - 1; 6 7 for (int k = beg; k <= end; ++k) 8 { 9 if (a[k] <= pivot) 10 { 11 swap(a, ++j, k); 12 } 13 } 14 for (int k = beg; k <= j; ++k) 15 { 16 if (a[k] < pivot) 17 { 18 swap(a, ++i, k); 19 } 20 } 21 //pivotNum is the number of the elements that equal to the pivot 22 if (pivotNum != NULL) 23 { 24 *pivotNum = j - i; 25 } 26 return j; 27 }
1 int ithSmallestLinear(int a[], int beg, int end, int i) 2 { 3 int len = end - beg + 1; 4 5 if (len < 140) 6 { 7 insertionSort(a, beg, end); 8 return beg + i - 1; 9 } 10 //divide the n elements into ⌊n/5⌋ groups and sort each group 11 for (int j = 0; j != len / 5; ++j) 12 { 13 int b = beg + j * 5; 14 int e = b + 4; 15 insertionSort(a, b, e); 16 //move the median of each group to the front of the array 17 swap(a, beg + j, b + 2); 18 } 19 //find the median of each median 20 int pivotLoc = ithSmallestLinear(a, beg, len / 5 - 1, (len / 5 + 1) / 2); 21 //the number of the elements that equal to the pivot 22 int pivotNum = 0; 23 int pivotEndIndex = partitionSpecifyPivot(a, beg, end, pivotLoc, &pivotNum); 24 int n = pivotEndIndex - beg + 1; 25 int m = n - pivotNum + 1; 26 27 if (i >= m && i <= n) 28 { 29 //return the index of the ith smallest element 30 return pivotEndIndex; 31 } 32 else if (i < m) 33 { 34 return ithSmallestLinear(a, beg, pivotEndIndex - pivotNum, i); 35 } 36 else 37 { 38 return ithSmallestLinear(a, pivotEndIndex + 1, end, i - n); 39 } 40 }
测试:
首先测试算法的正确性,代码如下:
1 #define ARRAY_SIZE 500 2 #define COUNT 10 3 4 int a[ARRAY_SIZE]; 5 int b[ARRAY_SIZE]; 6 7 int main(void) 8 { 9 for (int z = 0; z != COUNT; ++z) 10 { 11 result.open("result.txt"); 12 randArray(a, ARRAY_SIZE, 1, 9999); 13 copyArray(a, 0, b, 0, ARRAY_SIZE); 14 quickSort(a, 0, ARRAY_SIZE - 1); 15 16 for (int i = 1; i <= ARRAY_SIZE; ++i) 17 { 18 int resultStd = a[i - 1]; 19 int resultTest = b[ithSmallestLinear(b, 0, ARRAY_SIZE - 1, i)]; 20 21 // std::cout << "i = " << i << " resultTest = " << resultTest 22 // << " resultStd = " << resultStd << std::endl; 23 if (resultTest != resultStd) 24 { 25 std::cout << "Error" << std::endl; 26 return - 1; 27 } 28 } 29 std::cout << "test " << z << " done." << std::endl; 30 } 31 32 return 0; 33 }
关于该算法最坏情况下能保证O(n)的时间复杂度,只需测试其在数组元素随机、有序、相同这三种情况下的时间,代码如下:
(1000000数量级的测试,根据机器实际情况可以把这个调小点)
1 #define ARRAY_SIZE 1000000 2 #define COUNT 10 3 4 int a[ARRAY_SIZE]; 5 int b[ARRAY_SIZE]; 6 int c[ARRAY_SIZE]; 7 8 int main(void) 9 { 10 for (int z = 0; z != COUNT; ++z) 11 { 12 randArray(a, ARRAY_SIZE, 1, ARRAY_SIZE * 2); 13 randArray(c, ARRAY_SIZE, 1, 1); 14 copyArray(a, 0, b, 0, ARRAY_SIZE); 15 quickSort(b, 0, ARRAY_SIZE - 1); 16 17 for (int i = 1; i <= ARRAY_SIZE; ++i) 18 { 19 clock_t start = clock(); 20 ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i); 21 std::cout << clock() - start << "ms "; 22 start = clock(); 23 ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i); 24 std::cout << clock() - start << "ms "; 25 start = clock(); 26 ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i); 27 std::cout << clock() - start << "ms" << std::endl; 28 } 29 std::cout << "test " << z << " done" << std::endl; 30 } 31 return 0; 32 }
测试结果:
可以看到在三种情况下的时间是一个数量级的。
文中一些自定义函数的实现见文章“#include”