java8中List中sort方法解析
概述
集合类中的sort方法,听说在java7中就引入了,但是我没有用过java7,不太清楚,java8中的排序是采用Timsort排序算法实现的,这个排序最开始是在python中由Tim Peters实现的,后来Java觉得不错,就引入了这个排序到Java中,竟然以作者的名字命名,搞得我还以为这个Tim是一个单词的意思,了不起,本文就从Arrays中实现的排序分析一下这个排序算法的原理,本文只会从源码角度分析,不会从算法角度去分析。
进入List中查看sort方法源码如下:
default void sort(Comparator<? super E> c) { Object[] a = this.toArray(); // 这个方法很简单,就是调用Arrays中的sort方法进行排序 Arrays.sort(a, (Comparator) c); ListIterator<E> i = this.listIterator(); for (Object e : a) { i.next(); i.set((E) e); } }
进入Arrays.sort()方法
public static <T> void sort(T[] a, Comparator<? super T> c) { //这个是自己传入的比较器如果为空,这里的a为存放数据的数组,那数组中的的元素都必须实现Comparator接口 if (c == null) { sort(a); } else {
//这个判断没有细看,不太清楚在判断什么 if (LegacyMergeSort.userRequested) legacyMergeSort(a, c); else //这里就是所谓的TimSort了 TimSort.sort(a, 0, a.length, c, null, 0, 0); } }
由于sort()和TimSort.sort走的流程基本一致,这里只分析TimSort.sort()方法,进入该方法。
这里有必要先说一下TimSort排序算法的核心内容,了解这个算法的核心内容有助于看下面的代码。
TimSort的核心是这样:
1.如果数组的长度小于32,直接采用二分法插入排序,就是下面方法中的binarySort方法实现的,这个算法原理,我举个例子大家就明白了
假设数组为:[1,3,9,6,2],二分法插入排序插入如下:
I:从开头先把自然升序(或降序)段找出来,那什么是自然升序段,就是没有经过排序算法,原始数据就是有序的,本数组中自然升序段就是1,3,9
II:按照正常的思维,直接拿6,从头开始和前三个元素一个一个比较也可以实现排序,但是这样效率太低,那怎么做可以效率高点呢?就是我们之前高数中学的二分查找法,就是通过二分查找法,我先找到3,发现6 > 3,那我就不用和1进行比较了。
2.如果数组的长度大于32,那就把数组拆分成一个一个的小段,每段的长度在16~32之间,使用上面介绍的二分法插入排序,把每一段进行排序,之后在把每一个排好序的段进行合并,最终就可以实现整个数组的排序,大致的思想就是这样,这个叙述可能会给大家一种误解,就是认为每一段都排好序之后在进行合并,其实不是这样的,而是每一段边排序,如果符合特定条件就会合并。
有了上面的了解,我们再来看下面的代码
static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c, T[] work, int workBase, int workLen) { assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length; //这里是数组中剩余没有排序的元素个数,初始长度为数组的长度 int nRemaining = hi - lo; if (nRemaining < 2) return; // Arrays of size 0 and 1 are always sorted //这里的MIN_MERGE就是32,如果数组长度小于32,直接采用二分法插入排序 // If array is small, do a "mini-TimSort" with no merges if (nRemaining < MIN_MERGE) { int initRunLen = countRunAndMakeAscending(a, lo, hi, c); binarySort(a, lo, hi, lo + initRunLen, c); return; } /** * March over the array once, left to right, finding natural runs, * extending short natural runs to minRun elements, and merging runs * to maintain stack invariant. */ //这个是TimSort核心类,很多处理逻辑都是这个里面 TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen); //<1.1> 这个就是计算分割之后每一段的长度的 int minRun = minRunLength(nRemaining); do { // Identify next run //<1.2> 寻找自然增长的结束位置 int runLen = countRunAndMakeAscending(a, lo, hi, c); // If run is short, extend to min(minRun, nRemaining) if (runLen < minRun) { int force = nRemaining <= minRun ? nRemaining : minRun; //<1.3> 对每一段进行二分法插入排序 binarySort(a, lo, lo + force, lo + runLen, c); runLen = force; } // Push run onto pending-run stack, and maybe merge,将每一段的起始位置和每一段的分段长度放入栈中 ts.pushRun(lo, runLen); //<1.4> 合并排好序的段,这个就是上面我说的并不是等所有的都排好序了再合并 ts.mergeCollapse(); // Advance to find next run lo += runLen;
// 将已经排好序的段,从总长度中减去 nRemaining -= runLen; } while (nRemaining != 0); // Merge all remaining runs to complete sort assert lo == hi; //<1.5> 最终排序,这个方法在整个排序中只会执行一次 ts.mergeForceCollapse(); assert ts.stackSize == 1; }
上面注释中<1.1>, minRunLength(nRemaining)
private static int minRunLength(int n) { assert n >= 0; int r = 0; // Becomes 1 if any 1 bits are shifted off while (n >= MIN_MERGE) {
//这一段有点绕,其实意思就是当n是奇数的时候r = 1 r |= (n & 1);
//这段的意思就是一直除于2,直到n < 32为止 n >>= 1; }
//如果n是偶数,结果就是一直除于2的结果,如果是奇数,就是一直除于二加一 return n + r; }
上面注释<1.2>, countRunAndMakeAscending(a, lo, hi, c);
//解释一下各个参数:a就是存放元素的数组,lo第一个元素的位置(注意这个第一个元素并不一定是数组的第一个元素的位置,而是每一段的第一个元素),hi表示数组的长度,c为比较器 private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi, Comparator<? super T> c) { assert lo < hi; int runHi = lo + 1; if (runHi == hi) return 1; // Find end of run, and reverse range if descending //下面的if...else就是寻找自增序列的,if中判断的情况是寻找自然降序的 if (c.compare(a[runHi++], a[lo]) < 0) { // Descending while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0) runHi++;
//找到降序的段之后,进行反转成升序 reverseRange(a, lo, runHi); } else { // Ascending,前面的英文是原来的注释,可以看到这个是找升序段的位置的 while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0) runHi++; } //最后返回的结果其实第一个违反将序或升序的元素位置减去第一个元素的位置,举例[1,3,5,2,4],那么runHi=3,lo = 0,最后返回3 return runHi - lo; }
这个方法其实就是上面我在说TimSort原理的时候讲到的,寻找自然升序或将序段,这么做的原因是减小排序时候元素的个数,加快排序速度。
上面注释<1.3>,binarySort(a, lo, lo + force, lo + runLen, c);这个方法是核心排序方法,使用的是二分法插入排序算法
//先解释一下各个参数:a为存放元素的数组,lo是各个分段的起始位置,hi为数组的长度,start就是coutRunAndMakeAsending()方法返回的结果加上起始结果 private static <T> void binarySort(T[] a, int lo, int hi, int start, Comparator<? super T> c) { assert lo <= start && start <= hi; if (start == lo) start++; for ( ; start < hi; start++) { // 备份start位置的值,因为这个直后面可能被覆盖掉 T pivot = a[start]; // Set left (and right) to the index where a[start] (pivot) belongs int left = lo; int right = start; assert left <= right; /* * Invariants: * pivot >= all in [lo, left). * pivot < all in [right, start). */ //下面这个while循环就是一个二分查找法的过程,先确定二分查找法的范围,就是left和rigth,之后每次找到left和rigth的中间点,那后比较 while (left < right) {
//寻找中间点 int mid = (left + right) >>> 1;
//比较大小 if (c.compare(pivot, a[mid]) < 0) right = mid; else left = mid + 1; } assert left == right; /* * The invariants still hold: pivot >= all in [lo, left) and * pivot < all in [left, start), so pivot belongs at left. Note * that if there are elements equal to pivot, left points to the * first slot after them -- that's why this sort is stable. * Slide elements over to make room for pivot. */ int n = start - left; // The number of elements to move // Switch is just an optimization for arraycopy in default case ,这个switch case用的非常讲究,当你明白了这个玩意,你就不得不佩服大佬,看看真正的大佬是如何把普通的东西玩出不一样 switch (n) { case 2: a[left + 2] = a[left + 1]; case 1: a[left + 1] = a[left]; break; default: System.arraycopy(a, left, a, left + 1, n); } a[left] = pivot; } }
这个方法总的来说还是很好懂的,就是switch case那一块用的很牛逼,下面我就说一下这一块。我会对case对应的每一种情况举一个例子大家就明白这里为什么吊了。
case n = 2: 假设a = [1,2,5,8,9,6],当使用二分查找法定位的时候一定可以定位到5后面的位置,也就是a[3],这个时候要怎么做呢,这时left = 3,那a[left + 2] = a[left + 1];就是a[5] = a[4],就是把9放到6的位置。这还没有完,因为这个case后面没有continue也没有break,会继续执行后面的case,也就是会执行a[left + 1] = a[left];就是a[4] = a[3],相当于把8放到9的位置,这个时候来了一个break,后面的default就不会执行了,但是会执行a[left] = pivot;这句相当于a[3] = 6,这样就实现了一个交换动作(如果对switch...case中使用break不熟悉的建议先查一下这个)。
case n = 1: 假设a = [1,2,5,4],这时left = 2,执行a[left + 1] = a[left]相当于a[3] = a[2] = 5;最后执行a[left] = pivot; 就是a[2] = 4,这样做完之后就变成了[1,2,4,5]
default: 就是n > 2的时候,思考一下这里n不可能为0,因为n为0说明就是正常的升序,这个在前一个方法寻找自然自增序列的时候已经处理了这种情况,那就是当n > 2时候,举例如下:
a = [1,3,4,4,5,2],这种时候n = 4如果还是一个元素一个元素移动,那效率太低了,这时候使用了System.arraycopy方法,关于这个方法我写个例子大家看一下就知道这个方法可以干什么了。
public static void copyArray(){ Integer[] a = {1,3,4,4,5,2};
//关于这个参数大家可以看一下源码 System.arraycopy(a,1,a,2,4); for (int i = 0; i < a.length; i++) { System.out.println(a[i]); } } public static void main(String[] args) { copyArray(); }
输出结果为:1 3 3 4 4 5
可以看出这个方法的作用其实就是把数组元素下标1到4的元素拷贝到2到5,最后一个参数4表示拷贝的元素个数,可以发现最后一个元素2被覆盖了,细心的朋友可能发现这个结果也不是我们想要的排序结果啊,我们想要的是1 2 3 4 4 5,而现在是1 3 3 4 4 5,别急不是还有a[left] = pivot;这个方法吗,这个方法就是a[1] = 2,这样就完美了。简直是小母牛去南极,牛逼到了极点。
上面注释1.4,ts.mergeCollapse(),合并排好序的分段,再说这个方法之前有必要先说一下TimSort方法定义的几个属性,这几个属性会在下面的方法中用到
class TimSort<T> { private static final int MIN_GALLOP = 7; private T[] tmp; // private int tmpBase; // base of tmp array slice private int tmpLen; // length of tmp array slice private int stackSize = 0; // Number of pending runs on stack //这个里面存放就是每段待合并的分段的在数组的开始位置 private final int[] runBase; //这里面存放的是每个分段的长度,和上面的runBase是一一对应的 private final int[] runLen; }
ok,有了上面的认识,来开始看下面的代码
private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; //这个if判断的意义就是后面两段的长度之和一定要大于前面一段的长度才会执行 if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { if (runLen[n - 1] < runLen[n + 1]) n--;
// <2.1>执行合并逻辑 mergeAt(n); //如果上面的不满足,会执行这个,这个意思就是后面一段的长度要大于前面一段的长度 } else if (runLen[n] <= runLen[n + 1]) { mergeAt(n); //否则就不进行合并 } else { break; // Invariant is established } } }
这个方法主要的作用就是为了防止太短的段和比较长的段进行合并,浪费时间,举个例子现在runLen = [256 128 20 8],如果让8和20合并还可以接受,但是让20和128合并就太浪费时间了,至于说小段和大段合并为什么比较浪费,后面会分析,不过这里要做一个补充,其实在这里即便不做合并,后面还是会合并的,只是后面那个合并是一次性的,就是一个方法把所有的未合并分段全部合并了。
注释<2.1>,下面进入mergeAt(n);
private void mergeAt(int i) { assert stackSize >= 2; assert i >= 0; assert i == stackSize - 2 || i == stackSize - 3; int base1 = runBase[i]; int len1 = runLen[i]; int base2 = runBase[i + 1]; int len2 = runLen[i + 1]; assert len1 > 0 && len2 > 0; assert base1 + len1 == base2; /* * Record the length of the combined runs; if i is the 3rd-last * run now, also slide over the last run (which isn't involved * in this merge). The current run (i+1) goes away in any case. */ runLen[i] = len1 + len2; if (i == stackSize - 3) { runBase[i + 1] = runBase[i + 2]; runLen[i + 1] = runLen[i + 2]; } stackSize--; /* * Find where the first element of run2 goes in run1. Prior elements * in run1 can be ignored (because they're already in place). */ // int k = gallopRight(a[base2], a, base1, len1, 0, c); assert k >= 0; base1 += k; len1 -= k; if (len1 == 0) return; /* * Find where the last element of run1 goes in run2. Subsequent elements * in run2 can be ignored (because they're already in place). */ // len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c); assert len2 >= 0; if (len2 == 0) return; // Merge remaining runs, using tmp array with min(len1, len2) elements if (len1 <= len2) mergeLo(base1, len1, base2, len2); else
//<3.1> 合并
mergeHi(base1, len1, base2, len2); }
这个方法其实也没有执行合并的逻辑,那这个方法在干啥呢?其实这个方法还是在缩短比较的段的长度,其中两个主要的方法就是gallopRigth()和gallopLeft(),这个两个方法是在干啥呢?我下面还是举例说明吧。
假设:
第一个段为:[1,2,3,5,6,8,9]
第二个段为:[4,6,7,8,10,11,12]
第一个段在数组中的位置在第二个段前面,这里注意实际的段不可能这么短,上面有说段的长度应该在16到32之间,这里只是举例为了说明问题。
gallopRigth(): 寻找第二段的第一个元素在第一段中的位置,比如例子中位置为2,那也就是说第一段的前两个元素没有必要参与合并,他的位置不用动。
gallopLeft(): 寻找第一段的结尾元素在第二段中的位置,这里发现在第二段的第4,那也就是说第二段的10,11,12没有必要参与合并,同样是位置不需要改动。
最终参与合并的段为:
第一段:[5,6,8,9]
第二段:[4, 6, 7, 8]
这样参与合并的段的长度就大大减小,时间相应的就变短了,可能细心的小伙伴到这里就有一个疑问了,gallopRigth()是寻找第二段的第一个元素在第一段中的位置,而不是反过来,我觉得能想到这个疑问的朋友应该稍微思考一下就知道这个原因了,我就不多嘴了。所以gallopRigth()和gallopLeft()的源码我就不分析了,有兴趣的可以自己去看,里面还有些细节没有写到。
注释<3.1>,mergeHi(base1, len1, base2, len2);这个方法和注释<3.1>,mergeLo(base1, len1, base2, len2);是类似的方法,分析其中一个就可以了
//解释一下参数: base1 = 第一段的开始位置,len1 = 第一段的长度,base2 = 第二段的开始位置,len2 = 第二段的长度
private void mergeHi(int base1, int len1, int base2, int len2) { assert len1 > 0 && len2 > 0 && base1 + len1 == base2; // Copy second run into temp array T[] a = this.a; // For performance
//这里建立了一个空数组,目的是为了存放第二段的数据 T[] tmp = ensureCapacity(len2);
//这里的temBase是TimSort定义的一个属性,在TimSort初始化的时候给了一个初始化值0 int tmpBase = this.tmpBase;
//这里就是给上面新建的空数组赋值的,就是第二段放入到这个临时数组中 System.arraycopy(a, base2, tmp, tmpBase, len2); //下面定义两个游标,控制每一段比较的位置 int cursor1 = base1 + len1 - 1; // Indexes into a int cursor2 = tmpBase + len2 - 1; // Indexes into tmp array
//这个就是第二段的结束位置,这两段进行比较的时候也是从末尾开始比较,这里就是记录两段中比较大的元素会放入到这个位置上 int dest = base2 + len2 - 1; // Indexes into a // Move last element of first run and deal with degenerate cases
//从这一句就可以看出实现这个的作者的细致,解释一下这句,cursor1是第一段的结束位置,dest是第二段的结束位置,第一段的结束位置的值一定大于第二段的结束位置的值
//至于原因就是上面我分析的gallopLeft()方法和gollopRight() a[dest--] = a[cursor1--];
//这一句的意思就是如果len1 = 1,说明第二段都应该放入到第一段这个值的前面,原因还是上面的那个原因 if (--len1 == 0) { System.arraycopy(tmp, tmpBase, a, dest - (len2 - 1), len2); return; }
//这一句的意思就是如果第二段的长度为1,那就把他放入到第一段的前面 if (len2 == 1) { dest -= len1; cursor1 -= len1; System.arraycopy(a, cursor1 + 1, a, dest + 1, len1); a[dest] = tmp[cursor2]; return; } Comparator<? super T> c = this.c; // Use local variable for performance int minGallop = this.minGallop; // 这个minGallop = 7,是默认值,为啥搞这个值,看下面就知道了 outer: while (true) {
int count1 = 0; // Number of times in a row that first run won,其实这个count1和count2非常有意思,就是记录第一段中连续比第二段大的数的个数,注意是连续 int count2 = 0; // Number of times in a row that second run won,这个就是记录第二段中连续比第一段大的数字的个数 /* * Do the straightforward thing until (if ever) one run * appears to win consistently.
* 下面有两个do...while循环,其实是可以使用一个do...while循环实现的,但是作者为了优化,搞了两个do...while循环,下面我先说一下这个循环干了啥,为什么要搞两个
* 使用do...while循环的作用就是分别从第一段的最后一个数和第二段的最后一个数做比较,比较大小之后,谁比较大就放在dist的位置,这个位置其实就是从第二段结尾的位置逐渐减小
* 接下来说一下为什么要搞两个do...while循环,如果在这个比较当中发现count1 > 7或者count2 > 7,说明什么,说明第一段中有连续7个值大于第二段的未被比较的最大值,那就说明可能存在更多的值
* 大于第二段中的未被比较的值中的最大值,所以呢,他又调用了gallopLeft()和gallopRigth()把不需要比较的找出来进一步缩短合并的段的大小
* 看了上看的分析,是不是觉得作者非常牛逼,简直是小母牛掉进酒缸,醉牛逼 */ do { assert len1 > 0 && len2 > 1;
//比较第一段的最后一个元素和第二段的最后一个元素的大小 if (c.compare(tmp[cursor2], a[cursor1]) < 0) { a[dest--] = a[cursor1--]; count1++; count2 = 0; if (--len1 == 0) break outer; } else { a[dest--] = tmp[cursor2--]; count2++; count1 = 0; if (--len2 == 1) break outer; }
//如果count1 >= 7 或者count2 = 7,跳出循环,进入下一个循环中,把不需要合并的剔除 } while ((count1 | count2) < minGallop); /* * One run is winning so consistently that galloping may be a * huge win. So try that, and continue galloping until (if ever) * neither run appears to be winning consistently anymore.
* 这个do...while循环有点神奇,就是当count1>=7或者count2 >=7的时候就在这个里面实现合并,而不重新跳回第一个do...while循环
*
* */ do { assert len1 > 0 && len2 > 1;
//是不是又看到这个熟悉的方法 count1 = len1 - gallopRight(tmp[cursor2], a, base1, len1, len1 - 1, c); if (count1 != 0) { dest -= count1; cursor1 -= count1; len1 -= count1; System.arraycopy(a, cursor1 + 1, a, dest + 1, count1); if (len1 == 0) break outer; }
//这个do...while循环的合并就是采用这一句,每次这个合并完之后,重新去执行gallopRight或者gallopLeft方法,重新把不用合并的剔除掉 a[dest--] = tmp[cursor2--]; if (--len2 == 1) break outer; count2 = len2 - gallopLeft(a[cursor1], tmp, tmpBase, len2, len2 - 1, c); if (count2 != 0) { dest -= count2; cursor2 -= count2; len2 -= count2; System.arraycopy(tmp, cursor2 + 1, a, dest + 1, count2); if (len2 <= 1) // len2 == 1 || len2 == 0 break outer; }
//这个和上面类似,就是使用这个进行合并 a[dest--] = a[cursor1--]; if (--len1 == 0) break outer;
//这里的minGallop初始值是7,在这循环中,每循环一次就减1 minGallop--;
//如果发现第一段或者第二段的长度小于7了,就跳出这个循环 } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); if (minGallop < 0) minGallop = 0;
//重新给minGallop一个新值,跳回到第一个do...while中 minGallop += 2; // Penalize for leaving gallop mode } // End of "outer" loop this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field if (len2 == 1) { assert len1 > 0; dest -= len1; cursor1 -= len1; System.arraycopy(a, cursor1 + 1, a, dest + 1, len1); a[dest] = tmp[cursor2]; // Move first elt of run2 to front of merge } else if (len2 == 0) { throw new IllegalArgumentException( "Comparison method violates its general contract!"); } else { assert len1 == 0; assert len2 > 0; System.arraycopy(tmp, tmpBase, a, dest - (len2 - 1), len2); } }
总结:其实合并的过程就是在这个两个do...while之间来回跳的过程,而第二个do...while循环其实是对合并的一个优化,即便没有第二个循环也可以完成合并操作,不过要修改一下第一个循环的条件,而第二个循环是怎么优化的呢?这里就是作者的一个重要的思考了,就是当第二段的值连续大于第一段的某个值7次,是不是可以认为第二段中有可能有更多的值大于第一段呢?我觉得这个推断完全是正确的,做了这个优化之后就可以减少很多需要合并的值,这就是作者的厉害之处。
分析完以上合并过程,其实并没有完,为什么?因为上面1.4中说,并不是随便两个相邻的段都可以合并,而要满主一定的条件才可以合并,满足什么条件呢?其实上面已经说了,这里在重复一遍。
假设:连续的三段的长度x,y,z只要满足如下条件就合并:
x <= y + z. || y <=z
那相反的就是:
x > y+z. && y >z
也就是说满足上面条件的段就不会使用上面的方法进行合并,那这些个没有合并的段在哪里合并的呢,在下面的代码中合并。
注释<1.5>,ts.mergeForceCollapse();执行最终的合并
private void mergeForceCollapse() { while (stackSize > 1) { int n = stackSize - 2; if (n > 0 && runLen[n - 1] < runLen[n + 1]) n--; mergeAt(n); } }
这个里面stackSize就是栈的深度,其实就是没有合并的的段的多少,如果stackSize = 1说明什么,说明就只有一段了,那就说明已经合并完成了,至于mergeAt(n),这个方法我在上面已经介绍过了。
最后总结
这个排序算法其实还是很有必要看一下的,写的很有意思,里面做了大量的优化,从这些优化中我们可以学习到很多的东西,学到什么东西呢?可以看到这些大佬是怎么思考的,是怎么做事情的,在看这个代码的过程中我发现实现这个代码的作者真是非常的细致,每个能优化的点都考虑的非常清楚,简直是小母牛做钢锯,巨牛逼。另外在参考文章我也推荐大家好好看看,这篇文章的作者也很有意思,里面基本把上面的代码的过程给写出来了,只是没有把代码贴出来,感谢前辈。
参考文章: