很多人都说.net排序的效率高,抱着学习的态度观摩.net源代码。当数据类型为基本类型并且未指定排序规则时用的是本地代码(TrySZSort),否则使用快速排序的泛型实现。
// Declared extern with InternalCall: the body is not in this file, it is supplied by the runtime.
[MethodImpl(MethodImplOptions.InternalCall), ReliabilityContract(Consistency.MayCorruptInstance, Cer.MayFail)]
private static extern bool TrySZSort(Array keys, Array items, int left, int right);

/// <summary>
/// Sorts keys[left..right] (both bounds inclusive) in place, ordering elements with <paramref name="comparer"/>.
/// </summary>
/// <param name="keys">Array whose slice is sorted.</param>
/// <param name="left">Index of the first element of the slice.</param>
/// <param name="right">Index of the last element of the slice.</param>
/// <param name="comparer">Comparer that defines the ordering.</param>
internal static void QuickSort(T[] keys, int left, int right, IComparer<T> comparer)
{
    do
    {
        int low = left;
        int high = right;
        int middle = low + ((high - low) >> 1);

        // Order the first, middle and last elements so that the middle slot holds their median.
        ArraySortHelper<T>.SwapIfGreaterWithItems(keys, comparer, low, middle);
        ArraySortHelper<T>.SwapIfGreaterWithItems(keys, comparer, low, high);
        ArraySortHelper<T>.SwapIfGreaterWithItems(keys, comparer, middle, high);
        T pivot = keys[middle];

        // Partition around the pivot value.
        while (true)
        {
            while (comparer.Compare(keys[low], pivot) < 0)
            {
                low++;
            }
            while (comparer.Compare(pivot, keys[high]) < 0)
            {
                high--;
            }
            if (low > high)
            {
                break;
            }
            if (low < high)
            {
                T swapped = keys[low];
                keys[low] = keys[high];
                keys[high] = swapped;
            }
            low++;
            high--;
            if (low > high)
            {
                break;
            }
        }

        // Recurse only into the smaller partition; the larger one is handled by the next
        // iteration of the outer loop.
        bool leftPartIsSmaller = (high - left) <= (right - low);
        if (leftPartIsSmaller)
        {
            if (left < high)
            {
                ArraySortHelper<T>.QuickSort(keys, left, high, comparer);
            }
            left = low;
        }
        else
        {
            if (low < right)
            {
                ArraySortHelper<T>.QuickSort(keys, low, right, comparer);
            }
            right = high;
        }
    }
    while (left < right);
}

/// <summary>
/// Exchanges keys[a] and keys[b] when the indices differ and keys[a] compares greater than keys[b].
/// </summary>
private static void SwapIfGreaterWithItems(T[] keys, IComparer<T> comparer, int a, int b)
{
    if (a == b)
    {
        return;
    }
    if (comparer.Compare(keys[a], keys[b]) <= 0)
    {
        return;
    }
    T held = keys[a];
    keys[a] = keys[b];
    keys[b] = held;
}
经测试,TrySZSort的运行时间约为快速排序泛型实现的1/2。据推测,TrySZSort与QuickSort采用的是同样的算法,只不过TrySZSort是用本地代码以指针方式实现的。快速排序是我所知道的通用排序方法当中实际平均运行效率最高的,平均时间复杂度为O(n*log(n));当然最坏情况可能是O(n^2),但是出现最坏情况的概率是指数级小的。
很多人实现的快速排序都有一个致命缺陷,那就是没有控制递归深度,等于给自己埋了个地雷。我们知道快速排序每一次循环都会把数据分成两段,QuickSort采用双重循环,只对元素较少的一段进行递归,另一段则留在循环中继续处理,这样就把递归深度控制在log(N)之内,完美地实现了排雷处理。
QuickSort的实现真的很棒,但是在递归过程中却做了一些无用功,可以看到T[] keys与IComparer<T> comparer这俩一直就没变过,却被不停的传来传去。既然发现问题,那就解决它。
/// <summary>
/// Quicksort worker that keeps the array and the comparison delegate in fields, so the
/// recursive calls only have to pass the two bounds.
/// </summary>
private struct sorter<valueType>
{
    /// <summary>Array that is sorted in place.</summary>
    public valueType[] values;
    /// <summary>Comparison delegate; a positive result means the first argument sorts after the second.</summary>
    public Func<valueType, valueType, int> comparer;

    /// <summary>
    /// Sorts values[startIndex..endIndex] (both bounds inclusive) in place.
    /// </summary>
    public void sort(int startIndex, int endIndex)
    {
        do
        {
            valueType first = values[startIndex];
            valueType last = values[endIndex];
            int middle = (endIndex - startIndex) >> 1;
            if (middle == 0)
            {
                // At most two elements are left: order them directly and stop.
                if (comparer(first, last) > 0)
                {
                    values[startIndex] = last;
                    values[endIndex] = first;
                }
                return;
            }
            middle += startIndex;
            valueType pivot = values[middle];

            // Median of three: afterwards values[startIndex] <= pivot <= values[endIndex],
            // and values[middle] holds the pivot.
            if (comparer(first, pivot) > 0)
            {
                if (comparer(first, last) <= 0)
                {
                    values[startIndex] = pivot;
                    pivot = first;
                    values[middle] = pivot;
                }
                else
                {
                    values[endIndex] = first;
                    if (comparer(pivot, last) <= 0)
                    {
                        values[startIndex] = pivot;
                        pivot = last;
                        values[middle] = pivot;
                    }
                    else
                    {
                        values[startIndex] = last;
                    }
                }
            }
            else if (comparer(pivot, last) > 0)
            {
                values[endIndex] = pivot;
                if (comparer(first, last) <= 0)
                {
                    pivot = last;
                    values[middle] = pivot;
                }
                else
                {
                    values[startIndex] = last;
                    pivot = first;
                    values[middle] = pivot;
                }
            }

            // Partition the elements strictly between the two ends, which are already in place.
            int lo = startIndex + 1;
            int hi = endIndex - 1;
            for (;;)
            {
                while (comparer(values[lo], pivot) < 0)
                {
                    ++lo;
                }
                while (comparer(pivot, values[hi]) < 0)
                {
                    --hi;
                }
                if (lo >= hi)
                {
                    if (lo == hi)
                    {
                        ++lo;
                        --hi;
                    }
                    break;
                }
                valueType swapped = values[lo];
                values[lo] = values[hi];
                values[hi] = swapped;
                ++lo;
                --hi;
                if (lo > hi)
                {
                    break;
                }
            }

            // Recurse into the smaller part and keep looping over the larger one.
            if (hi - startIndex <= endIndex - lo)
            {
                if (startIndex < hi)
                {
                    sort(startIndex, hi);
                }
                startIndex = lo;
            }
            else
            {
                if (lo < endIndex)
                {
                    sort(lo, endIndex);
                }
                endIndex = hi;
            }
        }
        while (startIndex < endIndex);
    }
}

/// <summary>
/// Sorts the whole array in place and returns the same array instance.
/// A null array or an array with fewer than two elements is returned untouched.
/// </summary>
/// <param name="values">Array to sort.</param>
/// <param name="comparer">Comparison delegate; must not be null when there is something to sort.</param>
/// <returns>The array passed in.</returns>
public static valueType[] sort<valueType>(valueType[] values, Func<valueType, valueType, int> comparer)
{
    if (values == null || values.Length <= 1)
    {
        return values;
    }
    if (comparer == null)
    {
        throw showjim.sys.pub.nullException;
    }
    sorter<valueType> worker = new sorter<valueType>();
    worker.values = values;
    worker.comparer = comparer;
    worker.sort(0, values.Length - 1);
    return values;
}

/// <summary>
/// Sorts, in place, the part of the array that starts at <paramref name="startIndex"/> and spans
/// <paramref name="count"/> elements; the requested range is clipped to the array bounds.
/// </summary>
/// <param name="values">Array to sort.</param>
/// <param name="comparer">Comparison delegate; must not be null when there is something to sort.</param>
/// <param name="startIndex">First index of the range.</param>
/// <param name="count">Number of elements in the range.</param>
public static void sort<valueType>(valueType[] values, Func<valueType, valueType, int> comparer, int startIndex, int count)
{
    if (values == null || values.Length <= 1 || count <= 1)
    {
        return;
    }
    if (comparer == null)
    {
        throw showjim.sys.pub.nullException;
    }
    // The end is computed from the unclipped start, then both bounds are clipped to the array.
    int endIndex = startIndex + count;
    if (startIndex < 0)
    {
        startIndex = 0;
    }
    if (endIndex > values.Length)
    {
        endIndex = values.Length;
    }
    if (--endIndex - startIndex > 0)
    {
        sorter<valueType> worker = new sorter<valueType>();
        worker.values = values;
        worker.comparer = comparer;
        worker.sort(startIndex, endIndex);
    }
}
参考QuickSort主要修改了一下递归传参部分,Release实测运行时间降到QuickSort的80%,可以说QuickSort做了20%无用功。测试数据为100000000个随机整数。
下面用同样的方式来模拟TrySZSort。
/// <summary>
/// Quicksort worker specialised for int arrays; the array lives in a field so the recursive
/// calls only pass the two bounds.
/// </summary>
private struct intSorter
{
    /// <summary>Array that is sorted in place.</summary>
    public int[] values;

    /// <summary>
    /// Sorts values[startIndex..endIndex] (both bounds inclusive) in ascending order.
    /// </summary>
    public void sort(int startIndex, int endIndex)
    {
        do
        {
            int first = values[startIndex];
            int last = values[endIndex];
            int middle = (endIndex - startIndex) >> 1;
            if (middle == 0)
            {
                // At most two elements are left: order them directly and stop.
                if (first > last)
                {
                    values[startIndex] = last;
                    values[endIndex] = first;
                }
                return;
            }
            middle += startIndex;
            int pivot = values[middle];

            // Median of three: afterwards values[startIndex] <= pivot <= values[endIndex],
            // and values[middle] holds the pivot.
            if (first > pivot)
            {
                if (first <= last)
                {
                    values[startIndex] = pivot;
                    pivot = first;
                    values[middle] = pivot;
                }
                else
                {
                    values[endIndex] = first;
                    if (pivot <= last)
                    {
                        values[startIndex] = pivot;
                        pivot = last;
                        values[middle] = pivot;
                    }
                    else
                    {
                        values[startIndex] = last;
                    }
                }
            }
            else if (pivot > last)
            {
                values[endIndex] = pivot;
                if (first <= last)
                {
                    pivot = last;
                    values[middle] = pivot;
                }
                else
                {
                    values[startIndex] = last;
                    pivot = first;
                    values[middle] = pivot;
                }
            }

            // Partition the elements strictly between the two ends, which are already in place.
            int lo = startIndex + 1;
            int hi = endIndex - 1;
            for (;;)
            {
                while (values[lo] < pivot)
                {
                    ++lo;
                }
                while (pivot < values[hi])
                {
                    --hi;
                }
                if (lo >= hi)
                {
                    if (lo == hi)
                    {
                        ++lo;
                        --hi;
                    }
                    break;
                }
                int swapped = values[lo];
                values[lo] = values[hi];
                values[hi] = swapped;
                ++lo;
                --hi;
                if (lo > hi)
                {
                    break;
                }
            }

            // Recurse into the smaller part and keep looping over the larger one.
            if (hi - startIndex <= endIndex - lo)
            {
                if (startIndex < hi)
                {
                    sort(startIndex, hi);
                }
                startIndex = lo;
            }
            else
            {
                if (lo < endIndex)
                {
                    sort(lo, endIndex);
                }
                endIndex = hi;
            }
        }
        while (startIndex < endIndex);
    }
}

/// <summary>
/// Sorts the array in ascending order, in place, and returns the same array instance.
/// A null array or an array with fewer than two elements is returned untouched.
/// </summary>
/// <param name="values">Array to sort.</param>
/// <returns>The array passed in.</returns>
public static int[] sort(int[] values)
{
    if (values == null || values.Length <= 1)
    {
        return values;
    }
    intSorter worker = new intSorter();
    worker.values = values;
    worker.sort(0, values.Length - 1);
    return values;
}
Release实测,运行时间为TrySZSort的90%,效果不够理想,我们再把intSorter改为指针模式。
/// <summary>
/// Quicksort worker for int data that addresses the elements through raw pointers.
/// </summary>
private unsafe struct intSorter
{
    /// <summary>Pointer to the first element of the data; set by the caller, not read by sort itself.</summary>
    public int* values;

    /// <summary>
    /// Sorts the ints from <paramref name="startIndex"/> to <paramref name="endIndex"/>
    /// (both pointers inclusive) in ascending order.
    /// </summary>
    public void sort(int* startIndex, int* endIndex)
    {
        do
        {
            int first = *startIndex;
            int last = *endIndex;
            int half = (int)(endIndex - startIndex) >> 1;
            if (half == 0)
            {
                // At most two elements are left: order them directly and stop.
                if (first > last)
                {
                    *startIndex = last;
                    *endIndex = first;
                }
                return;
            }
            int* middle = startIndex + half;
            int pivot = *middle;

            // Median of three: afterwards *startIndex <= pivot <= *endIndex,
            // and *middle holds the pivot.
            if (first > pivot)
            {
                if (first <= last)
                {
                    *startIndex = pivot;
                    pivot = first;
                    *middle = pivot;
                }
                else
                {
                    *endIndex = first;
                    if (pivot <= last)
                    {
                        *startIndex = pivot;
                        pivot = last;
                        *middle = pivot;
                    }
                    else
                    {
                        *startIndex = last;
                    }
                }
            }
            else if (pivot > last)
            {
                *endIndex = pivot;
                if (first <= last)
                {
                    pivot = last;
                    *middle = pivot;
                }
                else
                {
                    *startIndex = last;
                    pivot = first;
                    *middle = pivot;
                }
            }

            // Partition the elements strictly between the two ends, which are already in place.
            int* lo = startIndex + 1;
            int* hi = endIndex - 1;
            for (;;)
            {
                while (*lo < pivot)
                {
                    ++lo;
                }
                while (pivot < *hi)
                {
                    --hi;
                }
                if (lo >= hi)
                {
                    if (lo == hi)
                    {
                        ++lo;
                        --hi;
                    }
                    break;
                }
                int swapped = *lo;
                *lo = *hi;
                *hi = swapped;
                ++lo;
                --hi;
                if (lo > hi)
                {
                    break;
                }
            }

            // Recurse into the smaller part and keep looping over the larger one.
            if (hi - startIndex <= endIndex - lo)
            {
                if (startIndex < hi)
                {
                    sort(startIndex, hi);
                }
                startIndex = lo;
            }
            else
            {
                if (lo < endIndex)
                {
                    sort(lo, endIndex);
                }
                endIndex = hi;
            }
        }
        while (startIndex < endIndex);
    }
}

/// <summary>
/// Sorts the array in ascending order, in place, and returns the same array instance.
/// The array is pinned with <c>fixed</c> for the duration of the sort.
/// A null array or an array with fewer than two elements is returned untouched.
/// </summary>
/// <param name="values">Array to sort.</param>
/// <returns>The array passed in.</returns>
public unsafe static int[] sort(int[] values)
{
    if (values == null || values.Length <= 1)
    {
        return values;
    }
    fixed (int* valueFixed = values)
    {
        intSorter worker = new intSorter();
        worker.values = valueFixed;
        worker.sort(valueFixed, valueFixed + values.Length - 1);
    }
    return values;
}
Release实测,运行时间为TrySZSort的80%,所以我猜测TrySZSort与QuickSort使用的是同样的算法。
有时候我们需要对数据进行分页,只取其中一页,这时候就没必要把所有数据都排好序,需要对排序程序做一些修改。
/// <summary>
/// Partial quicksort worker. After partitioning, a part is processed further only when it
/// overlaps the index window [skipCount, getEndIndex], so only that window ends up holding the
/// elements, in order, that a full sort would put there. Elements outside the window are
/// partitioned but not necessarily ordered.
/// </summary>
private struct rangeSorter<valueType>
{
    /// <summary>Array that is reordered in place.</summary>
    public valueType[] values;
    /// <summary>Comparison delegate; a positive result means the first argument sorts after the second.</summary>
    public Func<valueType, valueType, int> comparer;
    /// <summary>First index of the wanted window (inclusive).</summary>
    public int skipCount;
    /// <summary>Last index of the wanted window (inclusive).</summary>
    public int getEndIndex;

    /// <summary>
    /// Processes values[startIndex..endIndex] (both bounds inclusive).
    /// </summary>
    public void sort(int startIndex, int endIndex)
    {
        do
        {
            valueType leftValue = values[startIndex], rightValue = values[endIndex];
            int average = (endIndex - startIndex) >> 1;
            if (average == 0)
            {
                // At most two elements are left: order them directly and stop.
                if (comparer(leftValue, rightValue) > 0)
                {
                    values[startIndex] = rightValue;
                    values[endIndex] = leftValue;
                }
                break;
            }
            average += startIndex;
            int leftIndex = startIndex, rightIndex = endIndex;
            valueType value = values[average];

            // Median of three: afterwards values[startIndex] <= value <= values[endIndex],
            // and values[average] holds the pivot value.
            if (comparer(leftValue, value) <= 0)
            {
                if (comparer(value, rightValue) > 0)
                {
                    values[rightIndex] = value;
                    if (comparer(leftValue, rightValue) <= 0)
                    {
                        values[average] = value = rightValue;
                    }
                    else
                    {
                        values[leftIndex] = rightValue;
                        values[average] = value = leftValue;
                    }
                }
            }
            else if (comparer(leftValue, rightValue) <= 0)
            {
                values[leftIndex] = value;
                values[average] = value = leftValue;
            }
            else
            {
                values[rightIndex] = leftValue;
                if (comparer(value, rightValue) <= 0)
                {
                    values[leftIndex] = value;
                    values[average] = value = rightValue;
                }
                else
                {
                    values[leftIndex] = rightValue;
                }
            }

            // Partition the elements strictly between the two ends, which are already in place.
            ++leftIndex;
            --rightIndex;
            do
            {
                while (comparer(values[leftIndex], value) < 0)
                {
                    ++leftIndex;
                }
                while (comparer(value, values[rightIndex]) < 0)
                {
                    --rightIndex;
                }
                if (leftIndex < rightIndex)
                {
                    leftValue = values[leftIndex];
                    values[leftIndex] = values[rightIndex];
                    values[rightIndex] = leftValue;
                }
                else
                {
                    if (leftIndex == rightIndex)
                    {
                        ++leftIndex;
                        --rightIndex;
                    }
                    break;
                }
            }
            while (++leftIndex <= --rightIndex);

            // Recurse into the smaller part and loop over the larger one, but skip any part
            // that lies completely outside the wanted window.
            if (rightIndex - startIndex <= endIndex - leftIndex)
            {
                if (startIndex < rightIndex && rightIndex >= skipCount)
                {
                    sort(startIndex, rightIndex);
                }
                if (leftIndex > getEndIndex)
                {
                    break;
                }
                startIndex = leftIndex;
            }
            else
            {
                if (leftIndex < endIndex && leftIndex <= getEndIndex)
                {
                    sort(leftIndex, endIndex);
                }
                if (rightIndex < skipCount)
                {
                    break;
                }
                endIndex = rightIndex;
            }
        }
        while (startIndex < endIndex);
    }
}

/// <summary>
/// Reorders <paramref name="values"/> in place so that the page selected by
/// <paramref name="skipCount"/> and <paramref name="getCount"/> holds the elements a full sort
/// would put there, and returns a collection built over that page.
/// </summary>
/// <param name="values">Array to reorder; null yields null.</param>
/// <param name="comparer">Comparison delegate; must not be null.</param>
/// <param name="skipCount">Number of leading records to skip.</param>
/// <param name="getCount">Number of records wanted; interpreted by <c>range</c>.</param>
/// <returns>
/// Null when <paramref name="values"/> is null, an empty collection when the page is empty,
/// otherwise a collection constructed from (values, page start, page length).
/// </returns>
public static showjim.sys.collection<valueType> rangeSort<valueType>(valueType[] values, Func<valueType, valueType, int> comparer, int skipCount, int getCount)
{
    if (values == null)
    {
        return null;
    }
    if (comparer == null)
    {
        throw showjim.sys.pub.nullException;
    }
    showjim.sys.array.range range = new showjim.sys.array.range(values.Length, skipCount, getCount);
    if ((getCount = range.getCount) > 0)
    {
        // The array must be partitioned even for a one-element page; only a one-element
        // array needs no work.
        if (values.Length > 1)
        {
            new rangeSorter<valueType>
            {
                values = values,
                comparer = comparer,
                skipCount = range.skipCount,
                getEndIndex = range.endIndex - 1
            }.sort(0, values.Length - 1);
        }
        return new collection<valueType>(values, range.skipCount, getCount);
    }
    return new showjim.sys.collection<valueType>();
}

/// <summary>
/// Same as the four-argument overload, but restricted to the slice of <paramref name="values"/>
/// that starts at <paramref name="startIndex"/> and spans <paramref name="count"/> elements.
/// The slice is clipped to the array bounds, and the page is taken relative to the clipped slice.
/// </summary>
/// <param name="values">Array to reorder; null yields null.</param>
/// <param name="startIndex">First index of the slice; negative values are clipped to 0.</param>
/// <param name="count">Number of elements in the slice.</param>
/// <param name="comparer">Comparison delegate; must not be null.</param>
/// <param name="skipCount">Number of records to skip inside the slice.</param>
/// <param name="getCount">Number of records wanted.</param>
/// <returns>
/// Null when <paramref name="values"/> is null, an empty collection when the page is empty,
/// otherwise a collection constructed from (values, page start, page length).
/// </returns>
public static showjim.sys.collection<valueType> rangeSort<valueType>(valueType[] values, int startIndex, int count, Func<valueType, valueType, int> comparer, int skipCount, int getCount)
{
    if (values == null)
    {
        return null;
    }
    if (comparer == null)
    {
        throw showjim.sys.pub.nullException;
    }
    // Sums are widened to long so that large arguments (e.g. int.MaxValue meaning
    // "everything") cannot wrap around to a negative bound.
    long rangeEnd = (long)startIndex + count;
    if (startIndex < 0)
    {
        startIndex = 0;
    }
    if (rangeEnd > values.Length)
    {
        rangeEnd = values.Length;
    }
    if (rangeEnd > startIndex)
    {
        int endIndex = (int)rangeEnd;
        count = endIndex - startIndex;
        long pageStart = (long)skipCount + startIndex;
        long pageEnd = pageStart + getCount;
        if (pageStart < startIndex)
        {
            pageStart = startIndex;
        }
        if (pageEnd > endIndex)
        {
            pageEnd = endIndex;
        }
        if (pageEnd > pageStart)
        {
            // Both bounds now lie inside [startIndex, endIndex], so they fit in an int.
            skipCount = (int)pageStart;
            getCount = (int)(pageEnd - pageStart);
            // The slice must be partitioned even for a one-element page; only a
            // one-element slice needs no work.
            if (count > 1)
            {
                new rangeSorter<valueType>
                {
                    values = values,
                    comparer = comparer,
                    skipCount = skipCount,
                    getEndIndex = (int)pageEnd - 1
                }.sort(startIndex, endIndex - 1);
            }
            return new collection<valueType>(values, skipCount, getCount);
        }
    }
    return new showjim.sys.collection<valueType>();
}

/// <summary>
/// Data record range
/// </summary>
public struct range
{
    /// <summary>
    /// Total number of records
    /// </summary>
    private int Count;
    /// <summary>
    /// Start position
    /// </summary>
    private int StartIndex;
    /// <summary>
    /// Number of records skipped
    /// </summary>
    public int skipCount
    {
        get { return StartIndex; }
    }
    /// <summary>
    /// End position
    /// </summary>
    private int EndIndex;
    /// <summary>
    /// End position
    /// </summary>
    public int endIndex
    {
        get { return EndIndex; }
    }
    /// <summary>
    /// Number of records to get
    /// </summary>
    public int getCount
    {
        get { return EndIndex - StartIndex; }
    }
    /// <summary>
    /// Data record range
    /// </summary>
    /// <param name="count">Total number of records; negative values are treated as 0</param>
    /// <param name="skipCount">Number of records skipped</param>
    /// <param name="getCount">Number of records to get; a negative value selects everything after the skipped records</param>
    public range(int count, int skipCount, int getCount)
    {
        Count = count < 0 ? 0 : count;
        if (skipCount < Count && getCount != 0)
        {
            if (getCount > 0)
            {
                if (skipCount >= 0)
                {
                    StartIndex = skipCount;
                    // Count - skipCount is positive here, so comparing against it cannot
                    // overflow the way skipCount + getCount could.
                    EndIndex = getCount >= Count - skipCount ? Count : skipCount + getCount;
                }
                else
                {
                    // Negative skip: the range starts at 0 and the end is shortened accordingly.
                    StartIndex = 0;
                    if ((EndIndex = skipCount + getCount) > Count)
                    {
                        EndIndex = Count;
                    }
                    else if (EndIndex < 0)
                    {
                        EndIndex = 0;
                    }
                }
            }
            else
            {
                StartIndex = skipCount >= 0 ? skipCount : 0;
                EndIndex = Count;
            }
        }
        else
        {
            StartIndex = EndIndex = 0;
        }
    }
}
其实我这里要说的重点不是快速排序,而是借快速排序的实现说递归优化,不知你是否看明白了,因为递归也是很常用的。
快速排序在不同的场合还有不同的优化方法,比如对于较大的结构体可以创建一个用于排序的索引数组,比如对于排序字段计算复杂的可以创建一个用于排序的缓存值数组,比如数据段小于8的时候采用硬编码方式。