计算序列中第k小的数
作者:jostree 转载请注明出处 http://www.cnblogs.com/jostree/p/4046399.html
使用分治算法,首先选择随机选择轴值pivot,并使的序列中比pivot小的数在pivot左边,比pivot大的数在pivot右边,即快速排序算法中的partition的过程,可以参考:快速排序算法 Quick sort。
进行partition过程后,我们随机选择的轴值为序列的第j个,且其左边有a个数,右边有b个数。
如果j=k,那么说明该轴值就是第k小个数。
如果j>k,说明第k小的数一定在轴值的左边,我们可以递归的查找左侧a个数中的第k大个数。
如果j<k,说明第k小的数一定在轴值的右侧,且其左侧的a+1个数都小于第k小的数,所以我们可以递归的查找右侧b个数中的第k-a-1小的数。
由于每次规模缩小一半,且每次处理的时间为O(n),那么我们可以得到其平均复杂度为:T(n)=T(n/2)+n
根据主定理我们可以得到算法的复杂度为O(n)
代码如下:
1 #include <cstdio> 2 #include <cstdlib> 3 #include <iostream> 4 #include <vector> 5 #include <ctime> 6 using namespace std; 7 void swap(int &a, int &b) 8 { 9 int tmp = a; 10 a = b; 11 b = tmp; 12 } 13 int partation(vector<int> &vec, int begin, int end, int pivot) 14 { 15 swap(vec[pivot], vec[begin]); 16 int tmp = vec[begin]; 17 int p = begin, q = end; 18 while( p < q ) 19 { 20 while( p < q && vec[q] >= tmp) q--; 21 if( p < q ) vec[p++] = vec[q]; 22 while( p < q && vec[p] < tmp ) p++; 23 if( p < q ) vec[q--] = vec[p]; 24 } 25 vec[p] = tmp; 26 return p; 27 } 28 int FindKMin(vector<int> & vec, int begin, int end, int k) 29 { 30 srand(time(0)); 31 int piovt = rand()%(end-begin+1)+begin; 32 int pos = partation(vec, begin, end, piovt); 33 if( pos-begin+1 == k ) 34 { 35 return vec[pos]; 36 } 37 else if( pos-begin+1 > k ) 38 { 39 return FindKMin(vec, begin, pos-1, k); 40 } 41 else 42 { 43 return FindKMin(vec, pos+1, end, k-(pos-begin+1)); 44 } 45 } 46 int main(int argc, char *argv[]) 47 { 48 int n; 49 vector<int> vec; 50 while( cin >> n ) 51 { 52 vec.clear(); 53 int k; 54 cin>>k; 55 int tmp; 56 for( int i = 0 ; i < n ; i++ ) 57 { 58 cin>>tmp; 59 vec.push_back(tmp); 60 } 61 cout<<FindKMin(vec, 0, vec.size()-1, k)<<endl; 62 } 63 }
该方法的最坏复杂度为$O(n^2)$,最坏情况是第一阶段每次选择的轴都为最大的数,并递归计算前n-1个数的第k小数,直到第k小数为这个序列的最大值为止。第二阶段每次选择的轴为最小的数,直到只剩下第k小数这一个数。那么该复杂度为$O(n^2)$。
我们可以使用另一种方法来替代随机选择轴值,并使得该算法的最坏情况的复杂度也为O(n),选择轴值方法如下:
1. 将输入数组的n个元素划分为$\lfloor n/5 \rfloor$,每组5个元素,至多只有一组由剩下的nmod5个元素组成。
2. 寻找$\lceil n/5 \rceil$个组中每组的中位数,即对其进行排序,从而找到$\lceil n/5 \rceil$个中位数,并对这$\lceil n/5 \rceil$个中位数组成的数组继续递归调用找出其轴值。
使用该方法找到的轴值并不是数组真正的中位数。但是它具有一定的性质,在大于轴值的那些中位数的组且不包括最后个数少于5的那个组中,每组至少有3个数大于轴值。不计算这两个组,大于轴值的元素个数至少为:
\begin{equation} 3(\lceil \frac{1}{2}\lceil \frac{n}{5}\rceil\rceil -2) \geq \frac{3n}{10}-6 \end{equation}
从而时间复杂度为:$T(n) \leq T(\lceil n/5 \rceil ) + T(7n/10+6) + O(n) = O(n)$
选择轴值的代码,需要建立一个类unit来保存第i个数的值和其位置,并且最终返回轴值的位置。main函数包括了数组长度为4-6的数全排列的测试。
代码如下:
1 #include <cstdio> 2 #include <cstdlib> 3 #include <iostream> 4 #include <vector> 5 #include <algorithm> 6 using namespace std; 7 class unit 8 { 9 public: 10 int x, num; 11 unit(int xx=0, int nn=0) 12 { 13 x = xx; 14 num = nn; 15 } 16 bool operator < (const unit &a) const 17 { 18 return this->x < a.x; 19 } 20 }; 21 int getp(vector<unit> vec) 22 { 23 if( vec.size() <= 5 ) 24 { 25 sort(vec.begin(), vec.end()); 26 return vec[vec.size()/2].num; 27 } 28 vector<unit> small; 29 unit tmp; 30 for( int i = 0 ; i < vec.size()-vec.size()%5 ; i+=5 )//i is the start index in the team 31 { 32 sort(vec.begin()+i, vec.begin()+i+5); 33 // for( int j = 0 ; j < 5 ; j++ ) 34 // { 35 // cout<<vec[i+j].x<<" "; 36 // } 37 // cout<<endl; 38 tmp.x = vec[i+2].x; 39 tmp.num = vec[i+2].num; 40 small.push_back(tmp); 41 } 42 int remain = vec.size()%5; 43 int teamnum = vec.size()/5; 44 if( remain != 0 ) 45 { 46 sort(vec.begin()+teamnum*5, vec.end()); 47 tmp.x = vec[teamnum*5+remain/2].x; 48 tmp.num = vec[teamnum*5+remain/2].num; 49 small.push_back(tmp); 50 } 51 // for( int i = 0 ; i < small.size() ; i++ ) 52 // { 53 // cout<<small[i].x<<" "; 54 // } 55 // cout<<endl; 56 return getp(small); 57 } 58 int getpivot(const vector<int> vec) 59 { 60 vector<unit> vecunit; 61 unit tmp; 62 for( int i = 0 ; i < vec.size() ; i++ ) 63 { 64 tmp.x = vec[i]; 65 tmp.num = i; 66 vecunit.push_back(tmp); 67 } 68 return getp(vecunit); 69 } 70 int main(int argc, char *argv[]) 71 { 72 vector<int> a; 73 for( int i = 4 ; i < 7 ; i++ ) 74 { 75 a.clear(); 76 for( int j = 0 ; j < i ; j++ ) 77 { 78 a.push_back(j); 79 } 80 cout<<endl; 81 do{ 82 cout<<"array is "; 83 for( int k = 0 ; k < a.size() ; k++ ) 84 { 85 cout<<a[k]<<" "; 86 } 87 cout<<endl; 88 cout<<"result "<<getpivot(a)<<endl; 89 }while(next_permutation(a.begin(), a.end())); 90 } 91 }