STL 源码剖析:sort
概览
我大抵是太闲了。
此文章同步发表于我的知乎。
sort
作为最常用的 STL 之一,大多数人对于其了解仅限于快速排序。
听说其内部实现还包括插入排序和堆排序,于是很好奇,决定通过源代码一探究竟。
个人习惯使用 DEV-C++,不知道其他的编译器会不会有所不同,现阶段也不是很关心。
这个文章并不是析完之后的总结,而是边剖边写。不免有个人的猜测。而且由于本人英语极其差劲,大抵会犯一些憨憨错误。
源码部分
sort
首先,在 Dev 中输入以下代码:
#include<algorithm>
sort();
然后按住 ctrl,鼠标左键 sort
,就可以跳转到头文件 stl_algo.h
,并可以看到这个:
/**
* @brief Sort the elements of a sequence.
* @ingroup sorting_algorithms
* @param __first An iterator.
* @param __last Another iterator.
* @return Nothing.
*
* Sorts the elements in the range @p [__first,__last) in ascending order,
* such that for each iterator @e i in the range @p [__first,__last-1),
* *(i+1)<*i is false.
*
* The relative ordering of equivalent elements is not preserved, use
* @p stable_sort() if this is needed.
*/
template<typename _RandomAccessIterator>
inline void
sort(_RandomAccessIterator __first, _RandomAccessIterator __last)
{
// concept requirements
__glibcxx_function_requires(_Mutable_RandomAccessIteratorConcept<
_RandomAccessIterator>)
__glibcxx_function_requires(_LessThanComparableConcept<
typename iterator_traits<_RandomAccessIterator>::value_type>)
__glibcxx_requires_valid_range(__first, __last);
std::__sort(__first, __last, __gnu_cxx::__ops::__iter_less_iter());
}
/**
* @brief Sort the elements of a sequence using a predicate for comparison.
* @ingroup sorting_algorithms
* @param __first An iterator.
* @param __last Another iterator.
* @param __comp A comparison functor.
* @return Nothing.
*
* Sorts the elements in the range @p [__first,__last) in ascending order,
* such that @p __comp(*(i+1),*i) is false for every iterator @e i in the
* range @p [__first,__last-1).
*
* The relative ordering of equivalent elements is not preserved, use
* @p stable_sort() if this is needed.
*/
template<typename _RandomAccessIterator, typename _Compare>
inline void
sort(_RandomAccessIterator __first, _RandomAccessIterator __last,
_Compare __comp)
{
// concept requirements
__glibcxx_function_requires(_Mutable_RandomAccessIteratorConcept<
_RandomAccessIterator>)
__glibcxx_function_requires(_BinaryPredicateConcept<_Compare,
typename iterator_traits<_RandomAccessIterator>::value_type,
typename iterator_traits<_RandomAccessIterator>::value_type>)
__glibcxx_requires_valid_range(__first, __last);
std::__sort(__first, __last, __gnu_cxx::__ops::__iter_comp_iter(__comp));
}
注释、模板和函数参数不再解释,我们需要关注的是函数体。
但是,中间那一段没看懂……
点进去,是一堆看不懂的 #define
。
查了一下,感觉这东西不是我这个菜鸡能掌握的。
有兴趣的戳这里。
那么接下来,就应该去到函数 __sort
来一探究竟了。
__sort
通过同样的方法,继续在 stl_algo.h
里找到 __sort
的源代码。
template<typename _RandomAccessIterator, typename _Compare>
inline void
__sort(_RandomAccessIterator __first, _RandomAccessIterator __last,
_Compare __comp)
{
if (__first != __last)
{
std::__introsort_loop(__first, __last,
std::__lg(__last - __first) * 2,
__comp);
std::__final_insertion_sort(__first, __last, __comp);
}
}
同样,只看函数体部分。
一般来说,sort(a,a+n)
是对于区间 \([0,n)\) 进行排序,所以排序的前提是 __first != __last
。
如果能排序,那么通过两种方式:
-
__introsort_loop
翻译过来,好像叫内联循环。 -
__final_insertion_sort
应该叫末端插入排序吧。
一部分一部分的看。
__introsort_loop
/// This is a helper function for the sort routine.
template<typename _RandomAccessIterator, typename _Size, typename _Compare>
void
__introsort_loop(_RandomAccessIterator __first,
_RandomAccessIterator __last,
_Size __depth_limit, _Compare __comp)
{
while (__last - __first > int(_S_threshold))
{
if (__depth_limit == 0)
{
std::__partial_sort(__first, __last, __last, __comp);
return;
}
--__depth_limit;
_RandomAccessIterator __cut =
std::__unguarded_partition_pivot(__first, __last, __comp);
std::__introsort_loop(__cut, __last, __depth_limit, __comp);
__last = __cut;
}
}
最上边注释的翻译:这是排序例程的帮助程序函数。
在传参时,除了首尾迭代器和排序方式,还传了一个 std::__lg(__last - __first) * 2
,对应 __depth_limit
。
while
表示,当区间长度太小时,不进行排序。
_S_threshold
是一个由 enum
定义的数,好像是叫枚举类型。
当 __depth_limit
为 \(0\) 时,也就是迭代次数较多时,不使用 __introsort_loop
,而是使用 __partial_sort
(部分排序)。
然后通过 __unguarded_partition_pivot
,得到一个奇怪的位置(这个函数的翻译是无防护分区枢轴)。
然后递归处理这个奇怪的位置到末位置,再更新末位置,继续循环。
鉴于本人比较好奇无防护分区枢轴是什么,于是先看的 __unguarded_partition_pivot
。
__unguarded_partition_pivot
/// This is a helper function...
template<typename _RandomAccessIterator, typename _Compare>
inline _RandomAccessIterator
__unguarded_partition_pivot(_RandomAccessIterator __first,
_RandomAccessIterator __last, _Compare __comp)
{
_RandomAccessIterator __mid = __first + (__last - __first) / 2;
std::__move_median_to_first(__first, __first + 1, __mid, __last - 1,
__comp);
return std::__unguarded_partition(__first + 1, __last, __first, __comp);
}
首先,找到了中间点。
然后 __move_median_to_first
(把中间的数移到第一位)。
最后返回 __unguarded_partition
。
__move_median_to_first
这里的中间数,并不是数列的中间数,而是三个迭代器的中间值。
这三个迭代器分别指向:第二个数,中间的数,最后一个数。
/// Swaps the median value of *__a, *__b and *__c under __comp to *__result
template<typename _Iterator, typename _Compare>
void
__move_median_to_first(_Iterator __result,_Iterator __a, _Iterator __b,
_Iterator __c, _Compare __comp)
{
if (__comp(__a, __b))
{
if (__comp(__b, __c))
std::iter_swap(__result, __b);
else if (__comp(__a, __c))
std::iter_swap(__result, __c);
else
std::iter_swap(__result, __a);
}
else if (__comp(__a, __c))
std::iter_swap(__result, __a);
else if (__comp(__b, __c))
std::iter_swap(__result, __c);
else
std::iter_swap(__result, __b);
}
至于为什么取中间的数,暂时还不是很清楚。
__unguarded_partition
/// This is a helper function...
template<typename _RandomAccessIterator, typename _Compare>
_RandomAccessIterator
__unguarded_partition(_RandomAccessIterator __first,
_RandomAccessIterator __last,
_RandomAccessIterator __pivot, _Compare __comp)
{
while (true)
{
while (__comp(__first, __pivot))
++__first;
--__last;
while (__comp(__pivot, __last))
--__last;
if (!(__first < __last))
return __first;
std::iter_swap(__first, __last);
++__first;
}
}
传参传来的序列第二位到最后。
看着看着,我好像悟了。
这里应该就是实现快速排序的部分。
上边的 __move_median_to_first
是为了防止特殊数据卡 \(O(n^2)\)。经过移动的话,第一个位置就不会是最小值,放在左半序列的数也就不会为 \(0\)。
这样的话,__unguarded_partition
就是快排的主体。
那么,接下来该去看部分排序了。
__partial_sort
template<typename _RandomAccessIterator, typename _Compare>
inline void
__partial_sort(_RandomAccessIterator __first,
_RandomAccessIterator __middle,
_RandomAccessIterator __last,
_Compare __comp)
{
std::__heap_select(__first, __middle, __last, __comp);
std::__sort_heap(__first, __middle, __comp);
}
这里浅显的理解为堆排序,至于具体实现,在 stl_heap.h
里,不属于我们的讨论范围。
绝对不是因为我懒。
这样的话,__introsort_loop
就结束了。下一步就要回到 __sort
。
__final_insertion_sort
/// This is a helper function for the sort routine.
template<typename _RandomAccessIterator, typename _Compare>
void
__final_insertion_sort(_RandomAccessIterator __first,
_RandomAccessIterator __last, _Compare __comp)
{
if (__last - __first > int(_S_threshold))
{
std::__insertion_sort(__first, __first + int(_S_threshold), __comp);
std::__unguarded_insertion_sort(__first + int(_S_threshold), __last,
__comp);
}
else
std::__insertion_sort(__first, __last, __comp);
}
其中某常量为 enum { _S_threshold = 16 };
。
enum
,好像是枚举吧啦吧啦什么的,不了解。
其中实现的函数有两个:
__insertion_sort
插入排序。__unguarded_insertion_sort
无保护的插入排序(?
__insertion_sort
/// This is a helper function for the sort routine.
template<typename _RandomAccessIterator, typename _Compare>
void
__insertion_sort(_RandomAccessIterator __first,
_RandomAccessIterator __last, _Compare __comp)
{
if (__first == __last) return;
for (_RandomAccessIterator __i = __first + 1; __i != __last; ++__i)
{
if (__comp(__i, __first))
{
typename iterator_traits<_RandomAccessIterator>::value_type
__val = _GLIBCXX_MOVE(*__i);
_GLIBCXX_MOVE_BACKWARD3(__first, __i, __i + 1);
*__first = _GLIBCXX_MOVE(__val);
}
else
std::__unguarded_linear_insert(__i,
__gnu_cxx::__ops::__val_comp_iter(__comp));
}
}
其中的 __comp
依然按照默认排序方式 <
来理解。
_GLIBCXX_MOVE_BACKWARD3
进入到 _GLIBCXX_MOVE_BACKWARD3
,是一个神奇的 #define
:
#define _GLIBCXX_MOVE_BACKWARD3(_Tp, _Up, _Vp) std::move_backward(_Tp, _Up, _Vp)
#else
#define _GLIBCXX_MOVE_BACKWARD3(_Tp, _Up, _Vp) std::copy_backward(_Tp, _Up, _Vp)
#endif
其上就是 move_backward
:
/**
* @brief Moves the range [first,last) into result.
* @ingroup mutating_algorithms
* @param __first A bidirectional iterator.
* @param __last A bidirectional iterator.
* @param __result A bidirectional iterator.
* @return result - (first - last)
*
* The function has the same effect as move, but starts at the end of the
* range and works its way to the start, returning the start of the result.
* This inline function will boil down to a call to @c memmove whenever
* possible. Failing that, if random access iterators are passed, then the
* loop count will be known (and therefore a candidate for compiler
* optimizations such as unrolling).
*
* Result may not be in the range (first,last]. Use move instead. Note
* that the start of the output range may overlap [first,last).
*/
template<typename _BI1, typename _BI2>
inline _BI2
move_backward(_BI1 __first, _BI1 __last, _BI2 __result)
{
// concept requirements
__glibcxx_function_requires(_BidirectionalIteratorConcept<_BI1>)
__glibcxx_function_requires(_Mutable_BidirectionalIteratorConcept<_BI2>)
__glibcxx_function_requires(_ConvertibleConcept<
typename iterator_traits<_BI1>::value_type,
typename iterator_traits<_BI2>::value_type>)
__glibcxx_requires_valid_range(__first, __last);
return std::__copy_move_backward_a2<true>(std::__miter_base(__first),
std::__miter_base(__last),
__result);
}
上边的注释翻译为:
该函数与 move 具有相同的效果,但从范围的末尾开始,并一直运行到开始,返回结果的开头。
只要可能,这个内联函数将归结为对 @c memmove 的调用。否则,如果传递了随机访问迭代器,则循环计数将是已知的(因此是编译器优化(如展开)的候选项)。
结果可能不在范围 (first,last] 内。请改用 move。请注意,输出范围的开始可能与 [first,last) 重叠。
__unguarded_linear_insert
/// This is a helper function for the sort routine.
template<typename _RandomAccessIterator, typename _Compare>
void
__unguarded_linear_insert(_RandomAccessIterator __last,
_Compare __comp)
{
typename iterator_traits<_RandomAccessIterator>::value_type
__val = _GLIBCXX_MOVE(*__last);
_RandomAccessIterator __next = __last;
--__next;
while (__comp(__val, __next))
{
*__last = _GLIBCXX_MOVE(*__next);
__last = __next;
--__next;
}
*__last = _GLIBCXX_MOVE(__val);
}
翻译为“无防护线性插入”,应该是指直接插入吧。
当 __last
的值比前边元素的值小的时候,就一直进行交换,最后把 __last
放到对应的位置。
__unguarded_insertion_sort
/// This is a helper function for the sort routine.
template<typename _RandomAccessIterator, typename _Compare>
inline void
__unguarded_insertion_sort(_RandomAccessIterator __first,
_RandomAccessIterator __last, _Compare __comp)
{
for (_RandomAccessIterator __i = __first; __i != __last; ++__i)
std::__unguarded_linear_insert(__i,
__gnu_cxx::__ops::__val_comp_iter(__comp));
}
就是直接对区间的每个元素进行插入。
总结
到这里,sort
的源代码就剖完了(除了堆的那部分)。
虽然没怎么看懂,但也理解了,sort
的源码是在快排的基础上,通过堆排序和插入排序来维护时间复杂度的稳定,不至于退化为 \((n^2)\)。
鬼知道我写这么多是为了干嘛……