算法导论 Chapter 9.3 Selection in worst-case linear time

问题描述:

本节要求以最坏情况下O(n)的时间复杂度找到长度为n的数组中第 i 大的数。

 

解决方案:

《算法导论》上提供了一个算法,该算法实质上是利用了快排中划分的思想,但其通过一些比较复杂的预处理工作保证了快排划分的均匀,

并且能够从理论上证明其最坏情况下的时间复杂度可以达到O(n)。

算法步骤:

1、如图所示,将n个数分成5个一组,共有⌊n/5⌋组。

2、对⌈n/5⌉组(包括可能不到5个数的那组)的组内数据进行直接插入排序,排序完成之后,图中白色的数据即为组内数据的中位数。

    并将这⌈n/5⌉个中位数挑出来,具体做法见后面的代码。

3、递归调用本算法找到这⌈n/5⌉个中位数的中位数,假设它为图中的x。

4、假定n个数是互不相同的(后面会讨论一般情况),以x为枢轴对n个数进行一趟划分,

    使得x左边的数<x,x右边的数>x。假设划分完成之后,x在数组中从左到右排在第k个。

5、如果 i == k,那么x就是我们要找的数,return x; 即可;

    如果 i < k,那么递归调用本算法在x左边的数中继续找第 i 个数;

    如果 i > k,那么递归调用本算法在x右边的数中继续找第 i - k 个数;

 

为什么这个算法是O(n)的?《算法导论》上给予了证明。

证明步骤:

①第1、2、4步都是O(n)的。

其中第2步O(n)是因为,对n/5组进行组内直接插入排序的时间复杂度是(n/5) * (5^2) = 5n,也就是O(n)。

②根据第1、2、3步我们可以知道,由于n个数是互不相同的,且x是中值的中值,

则图中阴影部分的数据肯定都比x大,具体来说,比x大的数据至少有:

其中,⌈n/5⌉是组的个数(包括可能不到5个数的那组),-2是去掉x所在的组以及有可能不到5个数的那组。

同理,在x左上角的那部分数据也肯定比x小,并且也至少有这么多。

所以无论在第5步中是哪种情况,到下一次递归时最多有7n/10+6个数。

③假设本算法的最坏时间复杂度是T(n)的,那么第3步的时间复杂度是T(⌈n/5⌉),第5步的时间复杂度是T(7n/10+6)。

假设对于n<140的情况,找第 i 个数是O(1)的。(后面会解释为什么有这样的假设,以及为什么是140而不是410)

④现在我们只要证明,对于任意的n>0,都成找到一个常数 c 使得T(n) ≤ cn,那么,这个算法就是O(n)的。

假设式中O(n)项的常数因子为a,则有:

T(n) ≤ c ⌈n/5⌉ + c(7n/10 + 6) + an
       ≤ cn/5 + c + 7cn/10 + 6c + an
       = 9cn/10 + 7c + an
       = cn + (-cn/10 + 7c + an)

如果(-cn/10 + 7c + an) ≤ 0,那么T(n) ≤ cn。

(-cn/10 + 7c + an) ≤ 0 ⇒ c ≥ 10a(n/(n - 70))

当n≥140时,n/(n - 70) ≤ 2,只要取c = 20a即可使T(n) ≤ cn,亦即本算法是O(n)的。

 

下面讨论如果n个数中有重复数的情况

由于在算法步骤中的第四步以及证明步骤中的第二步都假定个n个数互不相同,这样才能保证能有3n/10-6个数一定比x小,

同时有3n/10-6个数一定比x大,这样无论i和k(x在数组中从左到右排在第k个)之间的大小关系是什么,总能下进入下一次

递归时排除掉3n/10-6个数。

而事实上,如果n个数中有重复元素与x相同,比如100个数(中间省略的都是2):

2 2 1 2 2 0 2  …… 2 2 1 2 3

这样≤2的数在前99个,大于2的数(只有3)在第100个,如果我们要找第5个数,那么按照前面的算法进入下次递归的是前99个数,

这显然不能满足“至少除掉3n/10-6个数”的假设。为了实现在这种情况下仍然是O(n),需要DIY一下划分算法,具体来说:

将n个数划分成三部分,第一部分<x,第二部分=x,第三部分>x。按上面的例子划分结果是(中间省略的都是2):

1 0 1 2 2 2 2  …… 2 2 2 2 3

我们知道从左至右第一个出现的x在第4个,最后出现的x在第99个,那么:

如果 i >= 4 && i <= 99,那么x就是我们要找的数,return x; 即可;

如果 i < 4,那么递归调用本算法在x左边的数中继续找第 i 个数;

如果 i > 99,那么递归调用本算法在x右边的数中继续找第 i - 99 个数;

这样,就能够保证进入每次“至少除掉3n/10-6个数”了。

 

实现代码:

View Code
 1  int partitionSpecifyPivot(int a[], int beg, int end, int pivotloc, int *pivotNum)
 2  {
 3      int pivot = a[pivotloc];
 4      int i = beg - 1;
 5      int j = beg - 1;
 6  
 7      for (int k = beg; k <= end; ++k)
 8      {
 9          if (a[k] <= pivot)
10          {
11              swap(a, ++j, k);
12          }
13      }
14      for (int k = beg; k <= j; ++k)
15      {
16          if (a[k] < pivot)
17          {
18              swap(a, ++i, k);
19          }
20      }
21      //pivotNum is the number of the elements that equal to the pivot
22      if (pivotNum != NULL)
23      {
24          *pivotNum = j - i;
25      }    
26      return j;
27  }
View Code
 1 int ithSmallestLinear(int a[], int beg, int end, int i)
 2 {
 3     int len = end - beg + 1;
 4 
 5     if (len < 140)
 6     {
 7         insertionSort(a, beg, end);
 8         return beg + i - 1;
 9     }
10     //divide the n elements into ⌊n/5⌋ groups and sort each group
11     for (int j = 0; j != len / 5; ++j)
12     {
13         int b = beg + j * 5;
14         int e = b + 4;
15         insertionSort(a, b, e);
16         //move the median of each group to the front of the array
17         swap(a, beg + j, b + 2);
18     }
19     //find the median of each median
20     int pivotLoc = ithSmallestLinear(a, beg, len / 5 - 1, (len / 5 + 1) / 2);
21     //the number of the elements that equal to the pivot
22     int pivotNum = 0;
23     int pivotEndIndex = partitionSpecifyPivot(a, beg, end, pivotLoc, &pivotNum);    
24     int n = pivotEndIndex - beg + 1;
25     int m = n - pivotNum + 1;
26 
27     if (i >= m && i <= n)
28     {
29         //return the index of the ith smallest element
30         return pivotEndIndex;
31     }
32     else if (i < m)
33     {
34         return ithSmallestLinear(a, beg, pivotEndIndex - pivotNum, i);
35     }
36     else
37     {
38         return ithSmallestLinear(a, pivotEndIndex + 1, end, i - n);
39     }
40 }

 

测试:

首先测试算法的正确性,代码如下:

View Code
 1 #define ARRAY_SIZE 500
 2 #define COUNT 10
 3 
 4 int a[ARRAY_SIZE];
 5 int b[ARRAY_SIZE];
 6 
 7 int main(void)
 8 {
 9     for (int z = 0; z != COUNT; ++z)
10     {
11         result.open("result.txt");
12         randArray(a, ARRAY_SIZE, 1, 9999);
13         copyArray(a, 0, b, 0, ARRAY_SIZE);
14         quickSort(a, 0, ARRAY_SIZE - 1);
15 
16         for (int i = 1; i <= ARRAY_SIZE; ++i)
17         {
18             int resultStd = a[i - 1];
19             int resultTest = b[ithSmallestLinear(b, 0, ARRAY_SIZE - 1, i)];
20 
21 //             std::cout << "i = " << i << " resultTest = " << resultTest 
22 //                       << " resultStd = " << resultStd << std::endl;
23             if (resultTest != resultStd)
24             {
25                 std::cout << "Error" << std::endl;
26                 return - 1;
27             }
28         }
29         std::cout << "test " << z << " done." << std::endl;
30     }
31 
32     return 0;
33 }

关于该算法最坏情况下能保证O(n)的时间复杂度,只需测试其在数组元素随机、有序、相同这三种情况下的时间,代码如下:

 (1000000数量级的测试,根据机器实际情况可以把这个调小点)

View Code
 1  #define ARRAY_SIZE 1000000
 2  #define COUNT 10
 3  
 4  int a[ARRAY_SIZE];
 5  int b[ARRAY_SIZE];
 6  int c[ARRAY_SIZE];
 7  
 8  int main(void)
 9  {
10      for (int z = 0; z != COUNT; ++z)
11      {
12          randArray(a, ARRAY_SIZE, 1, ARRAY_SIZE * 2);        
13          randArray(c, ARRAY_SIZE, 1, 1);
14          copyArray(a, 0, b, 0, ARRAY_SIZE);
15          quickSort(b, 0, ARRAY_SIZE - 1);
16  
17          for (int i = 1; i <= ARRAY_SIZE; ++i)
18          {
19              clock_t start = clock();
20              ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i);
21              std::cout << clock() - start << "ms ";
22              start = clock();
23              ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i);
24              std::cout << clock() - start << "ms ";
25              start = clock();
26              ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i);
27              std::cout << clock() - start << "ms" << std::endl;
28          }
29          std::cout << "test " << z << " done" << std::endl;
30      }
31      return 0;
32  }

 

测试结果:

 可以看到在三种情况下的时间是一个数量级的。

 

文中一些自定义函数的实现见文章“#include”

posted on 2012-11-01 22:57  Snser  阅读(2522)  评论(0编辑  收藏  举报