STL 源码剖析:sort

概览

我大抵是太闲了。

此文章同步发表于我的知乎

sort 作为最常用的 STL 之一,大多数人对于其了解仅限于快速排序。

听说其内部实现还包括插入排序和堆排序,于是很好奇,决定通过源代码一探究竟。

个人习惯使用 DEV-C++,不知道其他的编译器会不会有所不同,现阶段也不是很关心。

这个文章并不是析完之后的总结,而是边剖边写。不免有个人的猜测。而且由于本人英语极其差劲,大抵会犯一些憨憨错误。

源码部分

sort

首先,在 Dev 中输入以下代码:

#include<algorithm>
sort();

然后按住 ctrl,鼠标左键 sort,就可以跳转到头文件 stl_algo.h,并可以看到这个:

  /**
   *  @brief Sort the elements of a sequence.
   *  @ingroup sorting_algorithms
   *  @param  __first   An iterator.
   *  @param  __last    Another iterator.
   *  @return  Nothing.
   *
   *  Sorts the elements in the range @p [__first,__last) in ascending order,
   *  such that for each iterator @e i in the range @p [__first,__last-1),  
   *  *(i+1)<*i is false.
   *
   *  The relative ordering of equivalent elements is not preserved, use
   *  @p stable_sort() if this is needed.
  */
  template<typename _RandomAccessIterator>
    inline void
    sort(_RandomAccessIterator __first, _RandomAccessIterator __last)
    {
      // concept requirements
      __glibcxx_function_requires(_Mutable_RandomAccessIteratorConcept<
	    _RandomAccessIterator>)
      __glibcxx_function_requires(_LessThanComparableConcept<
	    typename iterator_traits<_RandomAccessIterator>::value_type>)
      __glibcxx_requires_valid_range(__first, __last);

      std::__sort(__first, __last, __gnu_cxx::__ops::__iter_less_iter());
    }

  /**
   *  @brief Sort the elements of a sequence using a predicate for comparison.
   *  @ingroup sorting_algorithms
   *  @param  __first   An iterator.
   *  @param  __last    Another iterator.
   *  @param  __comp    A comparison functor.
   *  @return  Nothing.
   *
   *  Sorts the elements in the range @p [__first,__last) in ascending order,
   *  such that @p __comp(*(i+1),*i) is false for every iterator @e i in the
   *  range @p [__first,__last-1).
   *
   *  The relative ordering of equivalent elements is not preserved, use
   *  @p stable_sort() if this is needed.
  */
  template<typename _RandomAccessIterator, typename _Compare>
    inline void
    sort(_RandomAccessIterator __first, _RandomAccessIterator __last,
	 _Compare __comp)
    {
      // concept requirements
      __glibcxx_function_requires(_Mutable_RandomAccessIteratorConcept<
	    _RandomAccessIterator>)
      __glibcxx_function_requires(_BinaryPredicateConcept<_Compare,
	    typename iterator_traits<_RandomAccessIterator>::value_type,
	    typename iterator_traits<_RandomAccessIterator>::value_type>)
      __glibcxx_requires_valid_range(__first, __last);

      std::__sort(__first, __last, __gnu_cxx::__ops::__iter_comp_iter(__comp));
    }

注释、模板和函数参数不再解释,我们需要关注的是函数体。

但是,中间那一段没看懂……

点进去,是一堆看不懂的 #define

查了一下,感觉这东西不是我这个菜鸡能掌握的。

有兴趣的戳这里

那么接下来,就应该去到函数 __sort 来一探究竟了。

__sort

通过同样的方法,继续在 stl_algo.h 里找到 __sort 的源代码。

  template<typename _RandomAccessIterator, typename _Compare>
    inline void
    __sort(_RandomAccessIterator __first, _RandomAccessIterator __last,
	   _Compare __comp)
    {
      if (__first != __last)
	{
	  std::__introsort_loop(__first, __last,
				std::__lg(__last - __first) * 2,
				__comp);
	  std::__final_insertion_sort(__first, __last, __comp);
	}
    }

同样,只看函数体部分。

一般来说,sort(a,a+n) 是对于区间 \([0,n)\) 进行排序,所以排序的前提是 __first != __last

如果能排序,那么通过两种方式:

  • __introsort_loop 翻译过来,好像叫内联循环

  • __final_insertion_sort 应该叫末端插入排序吧。

一部分一部分的看。

__introsort_loop

  /// This is a helper function for the sort routine.
  template<typename _RandomAccessIterator, typename _Size, typename _Compare>
    void
    __introsort_loop(_RandomAccessIterator __first,
		     _RandomAccessIterator __last,
		     _Size __depth_limit, _Compare __comp)
    {
      while (__last - __first > int(_S_threshold))
	{
	  if (__depth_limit == 0)
	    {
	      std::__partial_sort(__first, __last, __last, __comp);
	      return;
	    }
	  --__depth_limit;
	  _RandomAccessIterator __cut =
	    std::__unguarded_partition_pivot(__first, __last, __comp);
	  std::__introsort_loop(__cut, __last, __depth_limit, __comp);
	  __last = __cut;
	}
    }

最上边注释的翻译:这是排序例程的帮助程序函数。

在传参时,除了首尾迭代器和排序方式,还传了一个 std::__lg(__last - __first) * 2,对应 __depth_limit

while 表示,当区间长度太小时,不进行排序。

_S_threshold 是一个由 enum 定义的数,好像是叫枚举类型。

__depth_limit\(0\) 时,也就是迭代次数较多时,不使用 __introsort_loop,而是使用 __partial_sort(部分排序)。

然后通过 __unguarded_partition_pivot,得到一个奇怪的位置(这个函数的翻译是无防护分区枢轴)。

然后递归处理这个奇怪的位置到末位置,再更新末位置,继续循环。

鉴于本人比较好奇无防护分区枢轴是什么,于是先看的 __unguarded_partition_pivot

__unguarded_partition_pivot

  /// This is a helper function...
  template<typename _RandomAccessIterator, typename _Compare>
    inline _RandomAccessIterator
    __unguarded_partition_pivot(_RandomAccessIterator __first,
				_RandomAccessIterator __last, _Compare __comp)
    {
      _RandomAccessIterator __mid = __first + (__last - __first) / 2;
      std::__move_median_to_first(__first, __first + 1, __mid, __last - 1,
				  __comp);
      return std::__unguarded_partition(__first + 1, __last, __first, __comp);
    }

首先,找到了中间点。

然后 __move_median_to_first(把中间的数移到第一位)。

最后返回 __unguarded_partition

__move_median_to_first

这里的中间数,并不是数列的中间数,而是三个迭代器的中间值。

这三个迭代器分别指向:第二个数,中间的数,最后一个数。

  /// Swaps the median value of *__a, *__b and *__c under __comp to *__result
  template<typename _Iterator, typename _Compare>
    void
    __move_median_to_first(_Iterator __result,_Iterator __a, _Iterator __b,
			   _Iterator __c, _Compare __comp)
    {
      if (__comp(__a, __b))
	{
	  if (__comp(__b, __c))
	    std::iter_swap(__result, __b);
	  else if (__comp(__a, __c))
	    std::iter_swap(__result, __c);
	  else
	    std::iter_swap(__result, __a);
	}
      else if (__comp(__a, __c))
	std::iter_swap(__result, __a);
      else if (__comp(__b, __c))
	std::iter_swap(__result, __c);
      else
	std::iter_swap(__result, __b);
    }

至于为什么取中间的数,暂时还不是很清楚。

__unguarded_partition
  /// This is a helper function...
  template<typename _RandomAccessIterator, typename _Compare>
    _RandomAccessIterator
    __unguarded_partition(_RandomAccessIterator __first,
			  _RandomAccessIterator __last,
			  _RandomAccessIterator __pivot, _Compare __comp)
    {
      while (true)
	{
	  while (__comp(__first, __pivot))
	    ++__first;
	  --__last;
	  while (__comp(__pivot, __last))
	    --__last;
	  if (!(__first < __last))
	    return __first;
	  std::iter_swap(__first, __last);
	  ++__first;
	}
    }

传参传来的序列第二位到最后。

看着看着,我好像悟了。

这里应该就是实现快速排序的部分。

上边的 __move_median_to_first 是为了防止特殊数据卡 \(O(n^2)\)。经过移动的话,第一个位置就不会是最小值,放在左半序列的数也就不会为 \(0\)

这样的话,__unguarded_partition 就是快排的主体。

那么,接下来该去看部分排序了。

__partial_sort

  template<typename _RandomAccessIterator, typename _Compare>
    inline void
    __partial_sort(_RandomAccessIterator __first,
		   _RandomAccessIterator __middle,
		   _RandomAccessIterator __last,
		   _Compare __comp)
    {
      std::__heap_select(__first, __middle, __last, __comp);
      std::__sort_heap(__first, __middle, __comp);
    }

这里浅显的理解为堆排序,至于具体实现,在 stl_heap.h 里,不属于我们的讨论范围。

绝对不是因为我懒。

这样的话,__introsort_loop 就结束了。下一步就要回到 __sort

__final_insertion_sort

  /// This is a helper function for the sort routine.
  template<typename _RandomAccessIterator, typename _Compare>
    void
    __final_insertion_sort(_RandomAccessIterator __first,
			   _RandomAccessIterator __last, _Compare __comp)
    {
      if (__last - __first > int(_S_threshold))
	{
	  std::__insertion_sort(__first, __first + int(_S_threshold), __comp);
	  std::__unguarded_insertion_sort(__first + int(_S_threshold), __last,
					  __comp);
	}
      else
	std::__insertion_sort(__first, __last, __comp);
    }

其中某常量为 enum { _S_threshold = 16 };

enum,好像是枚举吧啦吧啦什么的,不了解。

其中实现的函数有两个:

  • __insertion_sort 插入排序。
  • __unguarded_insertion_sort 无保护的插入排序(?

__insertion_sort

  /// This is a helper function for the sort routine.
  template<typename _RandomAccessIterator, typename _Compare>
    void
    __insertion_sort(_RandomAccessIterator __first,
		     _RandomAccessIterator __last, _Compare __comp)
    {
      if (__first == __last) return;

      for (_RandomAccessIterator __i = __first + 1; __i != __last; ++__i)
	{
	  if (__comp(__i, __first))
	    {
	      typename iterator_traits<_RandomAccessIterator>::value_type
		__val = _GLIBCXX_MOVE(*__i);
	      _GLIBCXX_MOVE_BACKWARD3(__first, __i, __i + 1);
	      *__first = _GLIBCXX_MOVE(__val);
	    }
	  else
	    std::__unguarded_linear_insert(__i,
				__gnu_cxx::__ops::__val_comp_iter(__comp));
	}
    }

其中的 __comp 依然按照默认排序方式 < 来理解。

_GLIBCXX_MOVE_BACKWARD3

进入到 _GLIBCXX_MOVE_BACKWARD3,是一个神奇的 #define

#define _GLIBCXX_MOVE_BACKWARD3(_Tp, _Up, _Vp) std::move_backward(_Tp, _Up, _Vp)
#else
#define _GLIBCXX_MOVE_BACKWARD3(_Tp, _Up, _Vp) std::copy_backward(_Tp, _Up, _Vp)
#endif

其上就是 move_backward

  /**
   *  @brief Moves the range [first,last) into result.
   *  @ingroup mutating_algorithms
   *  @param  __first  A bidirectional iterator.
   *  @param  __last   A bidirectional iterator.
   *  @param  __result A bidirectional iterator.
   *  @return   result - (first - last)
   *
   *  The function has the same effect as move, but starts at the end of the
   *  range and works its way to the start, returning the start of the result.
   *  This inline function will boil down to a call to @c memmove whenever
   *  possible.  Failing that, if random access iterators are passed, then the
   *  loop count will be known (and therefore a candidate for compiler
   *  optimizations such as unrolling).
   *
   *  Result may not be in the range (first,last].  Use move instead.  Note
   *  that the start of the output range may overlap [first,last).
  */
  template<typename _BI1, typename _BI2>
    inline _BI2
    move_backward(_BI1 __first, _BI1 __last, _BI2 __result)
    {
      // concept requirements
      __glibcxx_function_requires(_BidirectionalIteratorConcept<_BI1>)
      __glibcxx_function_requires(_Mutable_BidirectionalIteratorConcept<_BI2>)
      __glibcxx_function_requires(_ConvertibleConcept<
	    typename iterator_traits<_BI1>::value_type,
	    typename iterator_traits<_BI2>::value_type>)
      __glibcxx_requires_valid_range(__first, __last);

      return std::__copy_move_backward_a2<true>(std::__miter_base(__first),
						std::__miter_base(__last),
						__result);
    }

上边的注释翻译为:

该函数与 move 具有相同的效果,但从范围的末尾开始,并一直运行到开始,返回结果的开头。
只要可能,这个内联函数将归结为对 @c memmove 的调用。否则,如果传递了随机访问迭代器,则循环计数将是已知的(因此是编译器优化(如展开)的候选项)。
结果可能不在范围 (first,last] 内。请改用 move。请注意,输出范围的开始可能与 [first,last) 重叠。
__unguarded_linear_insert
  /// This is a helper function for the sort routine.
  template<typename _RandomAccessIterator, typename _Compare>
    void
    __unguarded_linear_insert(_RandomAccessIterator __last,
			      _Compare __comp)
    {
      typename iterator_traits<_RandomAccessIterator>::value_type
	__val = _GLIBCXX_MOVE(*__last);
      _RandomAccessIterator __next = __last;
      --__next;
      while (__comp(__val, __next))
	{
	  *__last = _GLIBCXX_MOVE(*__next);
	  __last = __next;
	  --__next;
	}
      *__last = _GLIBCXX_MOVE(__val);
    }

翻译为“无防护线性插入”,应该是指直接插入吧。

__last 的值比前边元素的值小的时候,就一直进行交换,最后把 __last 放到对应的位置。

__unguarded_insertion_sort

  /// This is a helper function for the sort routine.
  template<typename _RandomAccessIterator, typename _Compare>
    inline void
    __unguarded_insertion_sort(_RandomAccessIterator __first,
			       _RandomAccessIterator __last, _Compare __comp)
    {
      for (_RandomAccessIterator __i = __first; __i != __last; ++__i)
	std::__unguarded_linear_insert(__i,
				__gnu_cxx::__ops::__val_comp_iter(__comp));
    }

就是直接对区间的每个元素进行插入。

总结

到这里,sort 的源代码就剖完了(除了堆的那部分)。

虽然没怎么看懂,但也理解了,sort 的源码是在快排的基础上,通过堆排序和插入排序来维护时间复杂度的稳定,不至于退化为 \((n^2)\)

鬼知道我写这么多是为了干嘛……

posted @ 2022-11-17 20:45  Zvelig1205  阅读(343)  评论(0编辑  收藏  举报