// A simplified stl_sort (一个精简版的stl_sort)
/************************************************************
 * A simplified version of the STL sort (introsort), modeled on
 * the GNU ISO C++ Library (libstdc++) implementation.
 * Uses some C++11 features (auto, decltype, std::move);
 * originally built with g++ 4.8.1.
 ************************************************************/
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <ctime>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Number of elements sorted in each benchmark round.
const int kMaxNum = 10000000;
// Threshold used by Sort: pieces of at most this many elements are not
// partitioned further and are finished by FinalInsertionSort.
const int kThreshold = 16;
// Number of benchmark rounds that main() runs and averages over.
const int kTestRounds = 3;

// Swaps the values referred to by two iterators. Uses moves so that
// non-trivial element types are not deep-copied.
template <typename ForwardIterator>
void IterSwap(ForwardIterator a, ForwardIterator b) {
  auto tmp = std::move(*a);
  *a = std::move(*b);
  *b = std::move(tmp);
}

// ---------------------------------------------------------------
// Heap algorithms. The heap is a max-heap with respect to `comp`
// (the root does not order before any other element); node i has
// children 2*i+1 and 2*i+2.
// ---------------------------------------------------------------

// Sifts `val` up from hole `index` towards `top`: parents that order
// before `val` are moved down into the hole, then `val` is stored in
// the final hole.
template <typename RandomAccessIterator, typename Distance, typename Tp,
          typename Compare>
void PushHeap(RandomAccessIterator first, Distance index, Distance top,
              Tp val, const Compare& comp) {
  Distance parent = (index - 1) / 2;
  while (index > top && comp(*(first + parent), val)) {
    *(first + index) = std::move(*(first + parent));
    index = parent;
    parent = (index - 1) / 2;
  }
  *(first + index) = std::move(val);
}

// Places `val` into the subtree rooted at `index` of the heap
// [first, first + len). The hole is first walked down to the bottom
// along the larger child, then `val` is sifted back up with PushHeap
// (the strategy used by libstdc++'s __adjust_heap).
template <typename RandomAccessIterator, typename Distance, typename Tp,
          typename Compare>
void AdjustHeap(RandomAccessIterator first, Distance index, Distance len,
                Tp val, const Compare& comp) {
  const Distance top = index;
  Distance child = 2 * index + 2;  // right child
  while (child < len) {
    if (comp(*(first + child), *(first + (child - 1))))
      --child;  // the left child is the larger one
    *(first + index) = std::move(*(first + child));
    index = child;
    child = 2 * index + 2;
  }
  if (child == len) {  // a lone left child remains at the bottom level
    *(first + index) = std::move(*(first + (child - 1)));
    index = child - 1;
  }
  ::PushHeap(first, index, top, std::move(val), comp);
}

// Turns [first, last) into a heap by adjusting every internal node,
// from the last one that has a child back to the root.
template <typename RandomAccessIterator, typename Compare>
void MakeHeap(RandomAccessIterator first, RandomAccessIterator last,
              const Compare& comp) {
  const auto len = last - first;
  if (len < 2) return;
  auto parent = (len - 2) / 2;  // last node that has at least one child
  for (;;) {
    auto val = std::move(*(first + parent));
    ::AdjustHeap(first, parent, len, std::move(val), comp);
    if (parent == 0) return;
    --parent;
  }
}

// Moves the heap top *first into *result, then re-heapifies
// [first, last) using the value that was previously stored at *result.
// With result == last this is the classic pop_heap on [first, last + 1);
// HeapSelect also calls it with `result` further past the heap to
// replace the top. In every caller here `result` lies outside
// [first, last).
template <typename RandomAccessIterator, typename Compare>
void PopHeap(RandomAccessIterator first, RandomAccessIterator last,
             RandomAccessIterator result, const Compare& comp) {
  assert(first < last);
  // AdjustHeap deduces a single Distance type from both index arguments;
  // a bare literal 0 (int) next to last - first (ptrdiff_t, e.g. long on
  // LP64) would make template argument deduction fail.
  using Distance = decltype(last - first);
  auto val = std::move(*result);
  *result = std::move(*first);
  ::AdjustHeap(first, Distance(0), Distance(last - first), std::move(val),
               comp);
}

// Sorts the heap [first, last) in ascending order by repeatedly popping
// the top to the end of the shrinking range.
template <typename RandomAccessIterator, typename Compare>
void HeapSort(RandomAccessIterator first, RandomAccessIterator last,
              const Compare& comp) {
  while (last - first > 1) {
    --last;
    ::PopHeap(first, last, last, comp);
  }
}

// Builds a heap on [first, middle), then, for each element of
// [middle, last) that orders before the heap top, swaps it into the
// heap. Afterwards [first, middle) holds the (middle - first) smallest
// elements of [first, last), arranged as a heap.
template <typename RandomAccessIterator, typename Compare>
void HeapSelect(RandomAccessIterator first, RandomAccessIterator middle,
                RandomAccessIterator last, const Compare& comp) {
  assert(first < middle);
  ::MakeHeap(first, middle, comp);
  for (RandomAccessIterator i = middle; i < last; ++i) {
    if (comp(*i, *first)) ::PopHeap(first, middle, i, comp);
  }
}
/**************************************************************/

// Partial sort implemented with a heap: the (middle - first) smallest
// elements of [first, last) end up sorted in [first, middle); the rest
// are left in [middle, last) in no particular order.
template <typename RandomAccessIterator, typename Compare>
inline void PartialSort(RandomAccessIterator first,
                        RandomAccessIterator middle,
                        RandomAccessIterator last, const Compare& comp) {
  if (first == middle) return;  // nothing to select
  ::HeapSelect(first, middle, last, comp);
  ::HeapSort(first, middle, comp);
}

// Median-of-three: moves the median of *a, *b, *c into *a. In every
// branch, one of *b, *c afterwards holds a value that does not order
// before the new *a, which bounds Partition's left scan.
template <typename ForwardIterator, typename Compare>
void MoveMedianFirst(ForwardIterator a, ForwardIterator b, ForwardIterator c,
                     const Compare& comp) {
  if (comp(*a, *b)) {
    if (comp(*b, *c)) {
      ::IterSwap(a, b);
    } else if (comp(*a, *c)) {
      ::IterSwap(a, c);
    }
  } else if (comp(*a, *c)) {
    return;
  } else if (comp(*b, *c)) {
    ::IterSwap(a, c);
  } else {
    ::IterSwap(a, b);
  }
}

// Unguarded Hoare-style partition of [first, last) around `pivot`.
// Neither scan checks bounds, so the caller must guarantee sentinels:
// in PartitionPivot the pivot itself sits at *(first - 1) and stops the
// right scan, and MoveMedianFirst leaves a value not less than the pivot
// inside the range to stop the left scan. Returns the start of the right
// part: elements before it do not order after `pivot`, elements from it
// on do not order before `pivot`.
template <typename RandomAccessIterator, typename Tp, typename Compare>
RandomAccessIterator Partition(RandomAccessIterator first,
                               RandomAccessIterator last, const Tp& pivot,
                               const Compare& comp) {
  for (;;) {
    while (comp(*first, pivot)) ++first;
    --last;
    while (comp(pivot, *last)) --last;
    if (!(first < last)) return first;
    ::IterSwap(first, last);
    ++first;
  }
}

// Chooses a median-of-three pivot, moves it to *first and partitions
// the remaining elements [first + 1, last) around it.
template <typename RandomAccessIterator, typename Compare>
inline RandomAccessIterator PartitionPivot(RandomAccessIterator first,
                                           RandomAccessIterator last,
                                           const Compare& comp) {
  assert(first < last);
  RandomAccessIterator mid = first + (last - first) / 2;
  ::MoveMedianFirst(first, mid, last - 1, comp);
  return ::Partition(first + 1, last, *first, comp);
}

// Introsort driver: partitions [first, last) until every piece has at
// most kThreshold elements, recursing on the right piece and looping on
// the left one. Once depth_limit reaches 0 the current piece is
// heap-sorted instead of being partitioned further. Small pieces are
// intentionally left unsorted for FinalInsertionSort.
template <typename RandomAccessIterator, typename Size, typename Compare>
void IntrosortLoop(RandomAccessIterator first, RandomAccessIterator last,
                   Size depth_limit, const Compare& comp) {
  while (last - first > kThreshold) {
    if (depth_limit == 0) {
      ::PartialSort(first, last, last, comp);  // heap-sort fallback
      return;
    }
    --depth_limit;
    RandomAccessIterator cut = ::PartitionPivot(first, last, comp);
    ::IntrosortLoop(cut, last, depth_limit, comp);
    last = cut;
  }
}

// Straight insertion sort over [first, last). After IntrosortLoop every
// element is already inside its final piece of at most kThreshold
// elements, so each insertion only shifts a few elements.
template <typename RandomAccessIterator, typename Compare>
void FinalInsertionSort(RandomAccessIterator first, RandomAccessIterator last,
                        const Compare& comp) {
  if (last - first < 2) return;
  for (RandomAccessIterator i = first + 1; i != last; ++i) {
    auto val = std::move(*i);
    RandomAccessIterator hole = i;
    // Stop at `first`: never form an iterator before the start of the
    // range (doing so is undefined behavior).
    while (hole != first && comp(val, *(hole - 1))) {
      *hole = std::move(*(hole - 1));
      --hole;
    }
    *hole = std::move(val);
  }
}

// Returns floor(log2(n)); requires n > 0.
template <class Size>
inline Size lg(Size n) {
  assert(n > 0);
  Size k = 0;
  for (; n != 1; n >>= 1) ++k;
  return k;
}

// Sorts [first, last) in ascending order according to `comp`
// (introsort followed by a final insertion sort, as in libstdc++).
// Empty and single-element ranges are already sorted.
template <typename RandomAccessIterator, typename Compare>
inline void Sort(RandomAccessIterator first, RandomAccessIterator last,
                 const Compare& comp) {
  if (last - first < 2) return;
  ::IntrosortLoop(first, last, ::lg(last - first) * 2, comp);
  ::FinalInsertionSort(first, last, comp);
}

// Fills every element of `array` with a pseudo-random value from rand().
void Init(std::vector<int>& array) {
  for (int& value : array) value = std::rand();
}

// Sorts `array` with ::Sort and returns the processor time used, in
// clock() ticks (divide by CLOCKS_PER_SEC for seconds). The result is
// verified with assert, so the check is compiled out under NDEBUG.
// Arrays with fewer than 2 elements are not timed and yield 0.0.
double Test(std::vector<int>& array) {
  if (array.size() < 2) return 0.0;
  const std::clock_t start_time = std::clock();
  ::Sort(array.begin(), array.end(), std::less<int>());
  const double used_time = static_cast<double>(std::clock() - start_time);
  assert(std::is_sorted(array.begin(), array.end()));
  return used_time;
}

// Benchmark: sorts kTestRounds freshly randomized arrays of kMaxNum ints
// and prints the average time per round in seconds.
int main() {
  std::srand(static_cast<unsigned>(std::time(nullptr)));
  std::vector<int> array(kMaxNum);
  double used_ticks = 0.0;
  for (int round = 0; round < kTestRounds; ++round) {
    Init(array);
    used_ticks += Test(array);
  }
  // Divide by the number of rounds actually run (previously the loop ran
  // once but the total was divided by 3).
  const double avg_used_time = used_ticks / kTestRounds / CLOCKS_PER_SEC;
  std::cout << "Average used time is " << avg_used_time << "(s)" << std::endl;
  return 0;
}