nth_element 测试程序

/********************************************************************
created:    2014/04/29 11:35
filename:    nth_element.cpp
author:        Justme0 (http://blog.csdn.net/justme0)

purpose:    simplified re-implementation of the STL nth_element algorithm
            and its helpers, with a small test driver in main
*********************************************************************/
  8 
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iterator>
 12 
 // Element type this demo is written for; the test array in main holds ints.
typedef int Type;

/*
** Copy [first, last) so that the copied range ends just before `result`,
** walking from back to front. Because the last element is copied first, the
** destination may overlap the source (e.g. shifting a range one slot to the
** right, as linear_insert does), provided `result` does not lie in
** (first, last].
** Returns the start of the destination range, i.e. result - (last - first).
**
** Element-wise assignment is used instead of memmove: memmove is only valid
** for trivially copyable types, whereas this template accepts any T.
*/
template <class T>
inline T * copy_backward(const T *first, const T *last, T *result) {
    while (first != last) {
        *--result = *--last;
    }
    return result;
}
 21 
 /*
** Move `value` leftwards into the sorted run that ends just before `last`:
** every element greater than `value` is shifted one slot to the right and
** `value` is written into the hole that remains.
** "unguarded": there is no check against the start of the range. The caller
** must guarantee that some element before `last` is not greater than `value`
** (linear_insert only calls this when !(value < *first)).
*/
template <class RandomAccessIterator, class T>
void unguarded_linear_insert(RandomAccessIterator last, T value) {
    RandomAccessIterator hole = last;
    for (RandomAccessIterator prev = hole - 1; value < *prev; --prev) {
        *hole = *prev;   // shift the larger element right, hole moves left
        hole = prev;
    }
    *hole = value;
}
 37 
 /*
** Insert the element at `last` into the already-sorted range [first, last),
** leaving [first, last] sorted.
** If the new element is smaller than *first, the whole range is moved one
** slot to the right in a single block copy and the element goes to the front.
** Otherwise *first is known not to exceed it; that acts as the sentinel which
** lets unguarded_linear_insert scan left without a bounds check.
*/
template <class RandomAccessIterator>
void linear_insert(RandomAccessIterator first, RandomAccessIterator last) {
    // Derive the element type from the iterator rather than the global `Type`
    // typedef, so non-int elements are not silently converted.
    typedef typename std::iterator_traits<RandomAccessIterator>::value_type value_type;
    value_type value = *last;
    if (value < *first) {
        copy_backward(first, last, last + 1);
        *first = value;
    } else {
        unguarded_linear_insert(last, value);
    }
}
 51 
 /*
** Sort [first, last) into ascending order by straight insertion: each element
** after the first is inserted into the sorted prefix that precedes it.
*/
template <class RandomAccessIterator>
void insertion_sort(RandomAccessIterator first, RandomAccessIterator last) {
    if (first == last) {
        return;
    }
    RandomAccessIterator cur = first;
    while (++cur != last) {
        linear_insert(first, cur);   // [first, cur) is sorted at this point
    }
}
 62 
 /*
** Return a reference to the median of a, b and c, using only operator<.
*/
template <class T>
inline const T & median(const T &a, const T &b, const T &c) {
    if (!(a < b)) {
        // b <= a
        if (a < c) {
            return a;                 // b <= a < c
        }
        return (b < c) ? c : b;       // b, c <= a: the larger of b and c
    }
    // a < b
    if (b < c) {
        return b;                     // a < b < c
    }
    return (a < c) ? c : a;           // c <= b: the larger of a and c
}
 81 
 /*
** Exchange the values referred to by iterators a and b.
** The temporary has the value type of ForwardIterator1 (taken from
** std::iterator_traits, as the real STL does) instead of the hard-coded
** `Type` typedef, so non-int elements are not truncated during the swap.
*/
template <class ForwardIterator1, class ForwardIterator2>
inline void iter_swap(ForwardIterator1 a, ForwardIterator2 b) {
    typename std::iterator_traits<ForwardIterator1>::value_type tmp = *a;
    *a = *b;
    *b = tmp;
}
 88 
 /*
** Partition [first, last) around the value `pivot` and return `cut` such that
** every element in [first, cut) is <= pivot and every element in
** [cut, last) is >= pivot.
** This is an internal STL algorithm, used by nth_element and sort.
** (Original author's note: it is unclear why partition is not used instead.)
** Neither inner scan checks the range bounds ("unguarded"). The caller is
** expected to pick `pivot` from the range itself (nth_element passes a
** median of three elements) so the scans stop before leaving the range.
** Both scans stop on elements equal to the pivot, and such elements are swapped.
*/
template <class RandomAccessIterator, class T>
RandomAccessIterator unguarded_partition(RandomAccessIterator first, RandomAccessIterator last, T pivot) {
    for (;;) {
        // Left cursor: skip elements strictly smaller than the pivot.
        while (*first < pivot) {
            ++first;
        }
        // Right cursor: step back, then skip elements strictly larger than
        // the pivot. (If std::partition were used with the predicate
        // "less than pivot", this comparison would be <= instead.)
        do {
            --last;
        } while (pivot < *last);
        // Cursors met or crossed: the split point is `first`.
        if (last <= first) {
            return first;
        }
        iter_swap(first, last);
        ++first;
    }
}
112 
/*
** Partially reorder [first, last) so that *nth is the element that would be
** at that position if the whole range were sorted, with no element before nth
** greater than it and no element after nth smaller than it.
** Quickselect: partition around a median-of-three pivot and keep only the
** part that contains nth, until at most 3 elements remain; that small range
** is then finished with insertion_sort.
*/
template <class RandomAccessIterator>
void nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last) {
    // Element type comes from the iterator instead of the global `Type`
    // typedef, so the pivot is not converted to int.
    typedef typename std::iterator_traits<RandomAccessIterator>::value_type value_type;
    while (last - first > 3) {
        // Copy the pivot out: median() returns a reference into the range,
        // and unguarded_partition moves elements around.
        const value_type pivot = median(*first,
                                        *(first + (last - first) / 2),
                                        *(last - 1));
        RandomAccessIterator cut = unguarded_partition(first, last, pivot);
        if (cut <= nth) {
            first = cut;   // nth lies in the right part [cut, last)
        } else {
            last = cut;    // nth lies in the left part [first, cut)
        }
    }
    insertion_sort(first, last);
}
128 
129 
130 int main(int argc, char **argv) {
131     int arr[] = {22, 30, 30, 17, 33, 40, 17, 23, 22, 12, 20};
132     int size = sizeof arr / sizeof *arr;
133 
134     nth_element(arr, arr + 5, arr + size);
135 
136     for (int i = 0; i < size; ++i) {
137         printf("%d ", arr[i]);    // 20 12 22 17 17 22 23 30 30 33 40
138     }
139     printf("\n");
140 
141     system("PAUSE");
142     return 0;
143 }

 

posted on 2014-05-01 12:05  jjtx  阅读(613)  评论(0)  编辑  收藏  举报

导航