C#效率极子 - 精益求精
自动化+性能优化

很多人都说.net排序的效率高,抱着学习的态度观摩.net源代码。当数据类型为基本类型并且未指定排序规则时用的是本地代码(TrySZSort),否则使用快速排序的泛型实现。

QuickSort
// Sort entry point implemented inside the runtime (InternalCall); its body is not visible here.
// NOTE(review): the bool return presumably reports whether the native path handled the sort,
// and left/right look like the inclusive bounds used by QuickSort - confirm against the runtime source.
[MethodImpl(MethodImplOptions.InternalCall), ReliabilityContract(Consistency.MayCorruptInstance, Cer.MayFail)]
private static extern bool TrySZSort(Array keys, Array items, int left, int right);

/// <summary>
/// Sorts keys[left..right] (both bounds inclusive) in place with a
/// median-of-three quicksort driven by <paramref name="comparer"/>.
/// </summary>
/// <remarks>
/// After each partition only the smaller side is sorted by a recursive call;
/// the larger side is handled by the next pass of the outer loop.
/// </remarks>
/// <param name="keys">Array whose elements are reordered.</param>
/// <param name="left">Index of the first element of the range.</param>
/// <param name="right">Index of the last element of the range.</param>
/// <param name="comparer">Ordering used for every comparison.</param>
internal static void QuickSort(T[] keys, int left, int right, IComparer<T> comparer)
{
    do
    {
        int lo = left;
        int hi = right;
        int mid = lo + ((hi - lo) >> 1);

        // Order keys[lo], keys[mid], keys[hi]; the median ends up in the middle slot.
        ArraySortHelper<T>.SwapIfGreaterWithItems(keys, comparer, lo, mid);
        ArraySortHelper<T>.SwapIfGreaterWithItems(keys, comparer, lo, hi);
        ArraySortHelper<T>.SwapIfGreaterWithItems(keys, comparer, mid, hi);
        T pivot = keys[mid];

        while (true)
        {
            // Skip elements that are already on the correct side of the pivot.
            while (comparer.Compare(keys[lo], pivot) < 0)
            {
                lo++;
            }
            while (comparer.Compare(pivot, keys[hi]) < 0)
            {
                hi--;
            }
            if (lo > hi)
            {
                break;
            }
            if (lo < hi)
            {
                T swapped = keys[lo];
                keys[lo] = keys[hi];
                keys[hi] = swapped;
            }
            lo++;
            hi--;
            if (lo > hi)
            {
                break;
            }
        }

        if ((hi - left) > (right - lo))
        {
            // Right side is the smaller one: recurse into it, keep looping on the left side.
            if (lo < right)
            {
                ArraySortHelper<T>.QuickSort(keys, lo, right, comparer);
            }
            right = hi;
        }
        else
        {
            // Left side is not larger: recurse into it, keep looping on the right side.
            if (left < hi)
            {
                ArraySortHelper<T>.QuickSort(keys, left, hi, comparer);
            }
            left = lo;
        }
    }
    while (left < right);
}
/// <summary>
/// Exchanges keys[a] and keys[b] when the two indexes differ and keys[a]
/// compares greater than keys[b]; otherwise leaves the array untouched.
/// </summary>
private static void SwapIfGreaterWithItems(T[] keys, IComparer<T> comparer, int a, int b)
{
    if (a == b)
    {
        return;
    }
    if (comparer.Compare(keys[a], keys[b]) <= 0)
    {
        return;
    }
    T held = keys[b];
    keys[b] = keys[a];
    keys[a] = held;
}

经测试TrySZSort的运行时间为快速排序泛型实现的1/2左右,据推测TrySZSort与QuickSort采用的是同样的算法,不过TrySZSort的实现是本地代码的指针模式。快速排序是我所知道的通用排序方法当中实际平均运行效率最高的,平均时间复杂度为O(n*log(n)),当然最坏情况可能是O(n^2),但是出现最坏情况的概率是指数级小的。
很多人实现的快速排序都有一个致命缺陷,那就是没有控制递归深度,等于给自己埋了个地雷。我们知道快速排序每一个循环都会把数据分成两段,QuickSort采用双重循环,只递归数量少的一段,另一段继续循环,这样就把递归深度控制在log(N)之内,完美的实现了排雷处理。


QuickSort的实现真的很棒,但是在递归过程中却做了一些无用功,可以看到T[] keys与IComparer<T> comparer这俩一直就没变过,却被不停的传来传去。既然发现问题,那就解决它。

Sort
        /// <summary>
        /// In-place quicksort over an array using a comparison delegate. The array and
        /// the delegate live in fields, so recursive calls pass only the two bounds.
        /// </summary>
        private struct sorter<valueType>
        {
            /// <summary>Array that is sorted in place.</summary>
            public valueType[] values;
            /// <summary>Comparison delegate; a result above zero means the first argument sorts after the second.</summary>
            public Func<valueType, valueType, int> comparer;
            /// <summary>
            /// Sorts values[startIndex..endIndex], both bounds inclusive.
            /// </summary>
            /// <param name="startIndex">Index of the first element of the range.</param>
            /// <param name="endIndex">Index of the last element of the range.</param>
            public void sort(int startIndex, int endIndex)
            {
                do
                {
                    valueType first = values[startIndex], last = values[endIndex];
                    int middle = (endIndex - startIndex) >> 1;
                    if (middle == 0)
                    {
                        // At most two elements: order the pair and stop.
                        if (comparer(first, last) > 0)
                        {
                            values[startIndex] = last;
                            values[endIndex] = first;
                        }
                        break;
                    }
                    middle += startIndex;
                    valueType pivot = values[middle];
                    // Median of three: afterwards values[startIndex] <= pivot <= values[endIndex]
                    // (per the comparer) and values[middle] holds the pivot.
                    if (comparer(first, pivot) > 0)
                    {
                        if (comparer(first, last) > 0)
                        {
                            values[endIndex] = first;
                            if (comparer(pivot, last) > 0)
                            {
                                values[startIndex] = last;
                            }
                            else
                            {
                                values[startIndex] = pivot;
                                values[middle] = pivot = last;
                            }
                        }
                        else
                        {
                            values[startIndex] = pivot;
                            values[middle] = pivot = first;
                        }
                    }
                    else if (comparer(pivot, last) > 0)
                    {
                        values[endIndex] = pivot;
                        if (comparer(first, last) > 0)
                        {
                            values[startIndex] = last;
                            values[middle] = pivot = first;
                        }
                        else
                        {
                            values[middle] = pivot = last;
                        }
                    }
                    // The two ends are already placed, so partitioning starts one step inside.
                    int lo = startIndex + 1, hi = endIndex - 1;
                    while (true)
                    {
                        while (comparer(values[lo], pivot) < 0) ++lo;
                        while (comparer(pivot, values[hi]) < 0) --hi;
                        if (lo >= hi)
                        {
                            if (lo == hi)
                            {
                                // Both cursors met on one element; step past it on both sides.
                                ++lo;
                                --hi;
                            }
                            break;
                        }
                        valueType moved = values[lo];
                        values[lo] = values[hi];
                        values[hi] = moved;
                        if (++lo > --hi) break;
                    }
                    if (hi - startIndex > endIndex - lo)
                    {
                        // Right part is smaller: recurse into it, continue the loop on the left part.
                        if (lo < endIndex) sort(lo, endIndex);
                        endIndex = hi;
                    }
                    else
                    {
                        // Left part is not larger: recurse into it, continue the loop on the right part.
                        if (startIndex < hi) sort(startIndex, hi);
                        startIndex = lo;
                    }
                }
                while (startIndex < endIndex);
            }
        }
        /// <summary>
        /// Sorts the whole array in place with the supplied comparison delegate.
        /// </summary>
        /// <param name="values">Array to sort; null and arrays with fewer than two elements are returned untouched.</param>
        /// <param name="comparer">Comparison delegate; it is only checked for null when there is something to sort.</param>
        /// <returns>The same <paramref name="values"/> reference that was passed in.</returns>
        public static valueType[] sort<valueType>(valueType[] values, Func<valueType, valueType, int> comparer)
        {
            if (values == null || values.Length <= 1) return values;
            if (comparer == null) throw showjim.sys.pub.nullException;
            sorter<valueType> worker = new sorter<valueType>();
            worker.values = values;
            worker.comparer = comparer;
            worker.sort(0, values.Length - 1);
            return values;
        }
        /// <summary>
        /// Sorts a sub-range of the array in place with the supplied comparison delegate.
        /// </summary>
        /// <param name="values">Array to sort; nothing happens for null or arrays with fewer than two elements.</param>
        /// <param name="comparer">Comparison delegate; it is only checked for null when there is something to sort.</param>
        /// <param name="startIndex">First index of the range; negative values are clamped to 0 after the end has been computed.</param>
        /// <param name="count">Number of elements counted from the original <paramref name="startIndex"/>; the range is clipped to the array length.</param>
        public static void sort<valueType>(valueType[] values, Func<valueType, valueType, int> comparer, int startIndex, int count)
        {
            if (values != null && values.Length > 1 && count > 1)
            {
                if (comparer == null) throw showjim.sys.pub.nullException;
                // Add in 64 bits: startIndex + count can exceed int.MaxValue (e.g. count == int.MaxValue
                // meaning "to the end"), which previously wrapped negative and silently skipped the sort.
                long requestedEnd = (long)startIndex + count;
                int endIndex = requestedEnd > values.Length ? values.Length : (int)requestedEnd;
                if (startIndex < 0) startIndex = 0;
                // endIndex becomes the inclusive last index; at least two elements are needed.
                if (--endIndex - startIndex > 0)
                {
                    new sorter<valueType> { values = values, comparer = comparer }.sort(startIndex, endIndex);
                }
            }
        }

参考QuickSort主要修改了一下递归传参部分,Release实测运行时间降到QuickSort的80%,可以说QuickSort做了20%无用功。测试数据为100000000个随机整数。

下面用同样的方式来模拟TrySZSort。

Sort
        /// <summary>
        /// In-place quicksort specialised for int arrays; comparisons use the built-in
        /// operators, and the array lives in a field so recursion passes only the bounds.
        /// </summary>
        private struct intSorter
        {
            /// <summary>Array that is sorted in place.</summary>
            public int[] values;
            /// <summary>
            /// Sorts values[startIndex..endIndex] ascending, both bounds inclusive.
            /// </summary>
            /// <param name="startIndex">Index of the first element of the range.</param>
            /// <param name="endIndex">Index of the last element of the range.</param>
            public void sort(int startIndex, int endIndex)
            {
                do
                {
                    int first = values[startIndex], last = values[endIndex];
                    int middle = (endIndex - startIndex) >> 1;
                    if (middle == 0)
                    {
                        // At most two elements: order the pair and stop.
                        if (first > last)
                        {
                            values[startIndex] = last;
                            values[endIndex] = first;
                        }
                        break;
                    }
                    middle += startIndex;
                    int pivot = values[middle];
                    // Median of three: afterwards values[startIndex] <= pivot <= values[endIndex]
                    // and values[middle] holds the pivot.
                    if (first > pivot)
                    {
                        if (first > last)
                        {
                            values[endIndex] = first;
                            if (pivot > last)
                            {
                                values[startIndex] = last;
                            }
                            else
                            {
                                values[startIndex] = pivot;
                                values[middle] = pivot = last;
                            }
                        }
                        else
                        {
                            values[startIndex] = pivot;
                            values[middle] = pivot = first;
                        }
                    }
                    else if (pivot > last)
                    {
                        values[endIndex] = pivot;
                        if (first > last)
                        {
                            values[startIndex] = last;
                            values[middle] = pivot = first;
                        }
                        else
                        {
                            values[middle] = pivot = last;
                        }
                    }
                    // The two ends are already placed, so partitioning starts one step inside.
                    int lo = startIndex + 1, hi = endIndex - 1;
                    while (true)
                    {
                        while (values[lo] < pivot) ++lo;
                        while (pivot < values[hi]) --hi;
                        if (lo >= hi)
                        {
                            if (lo == hi)
                            {
                                // Both cursors met on one element; step past it on both sides.
                                ++lo;
                                --hi;
                            }
                            break;
                        }
                        int moved = values[lo];
                        values[lo] = values[hi];
                        values[hi] = moved;
                        if (++lo > --hi) break;
                    }
                    if (hi - startIndex > endIndex - lo)
                    {
                        // Right part is smaller: recurse into it, continue the loop on the left part.
                        if (lo < endIndex) sort(lo, endIndex);
                        endIndex = hi;
                    }
                    else
                    {
                        // Left part is not larger: recurse into it, continue the loop on the right part.
                        if (startIndex < hi) sort(startIndex, hi);
                        startIndex = lo;
                    }
                }
                while (startIndex < endIndex);
            }
        }
        /// <summary>
        /// Sorts an int array ascending in place.
        /// </summary>
        /// <param name="values">Array to sort; null and arrays with fewer than two elements are returned untouched.</param>
        /// <returns>The same <paramref name="values"/> reference that was passed in.</returns>
        public static int[] sort(int[] values)
        {
            if (values == null || values.Length <= 1) return values;
            intSorter worker = new intSorter();
            worker.values = values;
            worker.sort(0, values.Length - 1);
            return values;
        }

Release实测,运行时间为TrySZSort的90%,效果不够理想,我们再把intSorter改为指针模式。

Sort<int*>
        /// <summary>
        /// Pointer-based variant of the int quicksort: the range is described by two
        /// element pointers instead of array indexes.
        /// </summary>
        private unsafe struct intSorter
        {
            /// <summary>Pointer supplied by the caller; sort itself does not read this field.</summary>
            public int* values;
            /// <summary>
            /// Sorts the ints from *startIndex through *endIndex ascending, both ends inclusive.
            /// </summary>
            /// <param name="startIndex">Pointer to the first element of the range.</param>
            /// <param name="endIndex">Pointer to the last element of the range.</param>
            public void sort(int* startIndex, int* endIndex)
            {
                do
                {
                    int first = *startIndex, last = *endIndex;
                    int half = (int)(endIndex - startIndex) >> 1;
                    if (half == 0)
                    {
                        // At most two elements: order the pair and stop.
                        if (first > last)
                        {
                            *startIndex = last;
                            *endIndex = first;
                        }
                        break;
                    }
                    int* middle = startIndex + half;
                    int pivot = *middle;
                    // Median of three: afterwards *startIndex <= pivot <= *endIndex
                    // and *middle holds the pivot.
                    if (first > pivot)
                    {
                        if (first > last)
                        {
                            *endIndex = first;
                            if (pivot > last)
                            {
                                *startIndex = last;
                            }
                            else
                            {
                                *startIndex = pivot;
                                *middle = pivot = last;
                            }
                        }
                        else
                        {
                            *startIndex = pivot;
                            *middle = pivot = first;
                        }
                    }
                    else if (pivot > last)
                    {
                        *endIndex = pivot;
                        if (first > last)
                        {
                            *startIndex = last;
                            *middle = pivot = first;
                        }
                        else
                        {
                            *middle = pivot = last;
                        }
                    }
                    // The two ends are already placed, so partitioning starts one step inside.
                    int* lo = startIndex + 1, hi = endIndex - 1;
                    while (true)
                    {
                        while (*lo < pivot) ++lo;
                        while (pivot < *hi) --hi;
                        if (lo >= hi)
                        {
                            if (lo == hi)
                            {
                                // Both cursors met on one element; step past it on both sides.
                                ++lo;
                                --hi;
                            }
                            break;
                        }
                        int moved = *lo;
                        *lo = *hi;
                        *hi = moved;
                        if (++lo > --hi) break;
                    }
                    if (hi - startIndex > endIndex - lo)
                    {
                        // Right part is smaller: recurse into it, continue the loop on the left part.
                        if (lo < endIndex) sort(lo, endIndex);
                        endIndex = hi;
                    }
                    else
                    {
                        // Left part is not larger: recurse into it, continue the loop on the right part.
                        if (startIndex < hi) sort(startIndex, hi);
                        startIndex = lo;
                    }
                }
                while (startIndex < endIndex);
            }
        }
        /// <summary>
        /// Sorts an int array ascending in place using the pointer-based sorter;
        /// the array is pinned for the duration of the sort.
        /// </summary>
        /// <param name="values">Array to sort; null and arrays with fewer than two elements are returned untouched.</param>
        /// <returns>The same <paramref name="values"/> reference that was passed in.</returns>
        public unsafe static int[] sort(int[] values)
        {
            if (values == null || values.Length <= 1) return values;
            fixed (int* head = values)
            {
                intSorter worker = new intSorter();
                worker.values = head;
                worker.sort(head, head + values.Length - 1);
            }
            return values;
        }

Release实测,运行时间为TrySZSort的80%,所以我猜测TrySZSort与QuickSort使用的是同样的算法。

 

有时候我们需要对数据进行分页取其中一页,这时候就没必要把所有数据都排好序,这个时候需要对排序程序做一些修改。

RangeSort
        /// <summary>
        /// Partial quicksort: partitions like a full quicksort but skips any partition that
        /// lies entirely outside the index window [skipCount, getEndIndex].
        /// </summary>
        private struct rangeSorter<valueType>
        {
            /// <summary>Array that is reordered in place.</summary>
            public valueType[] values;
            /// <summary>Comparison delegate; a result above zero means the first argument sorts after the second.</summary>
            public Func<valueType, valueType, int> comparer;
            /// <summary>First index of the window.</summary>
            public int skipCount;
            /// <summary>Last index of the window (inclusive).</summary>
            public int getEndIndex;
            /// <summary>
            /// Partitions values[startIndex..endIndex] (both bounds inclusive), descending only
            /// into partitions that overlap the window.
            /// </summary>
            /// <param name="startIndex">Index of the first element of the range.</param>
            /// <param name="endIndex">Index of the last element of the range.</param>
            public void sort(int startIndex, int endIndex)
            {
                do
                {
                    valueType first = values[startIndex], last = values[endIndex];
                    int middle = (endIndex - startIndex) >> 1;
                    if (middle == 0)
                    {
                        // At most two elements: order the pair and stop.
                        if (comparer(first, last) > 0)
                        {
                            values[startIndex] = last;
                            values[endIndex] = first;
                        }
                        break;
                    }
                    // The pivot is taken from the midpoint; the original carried a disabled experiment
                    // that clamped this index into [skipCount, getEndIndex].
                    middle += startIndex;
                    valueType pivot = values[middle];
                    // Median of three: afterwards values[startIndex] <= pivot <= values[endIndex]
                    // (per the comparer) and values[middle] holds the pivot.
                    if (comparer(first, pivot) > 0)
                    {
                        if (comparer(first, last) > 0)
                        {
                            values[endIndex] = first;
                            if (comparer(pivot, last) > 0)
                            {
                                values[startIndex] = last;
                            }
                            else
                            {
                                values[startIndex] = pivot;
                                values[middle] = pivot = last;
                            }
                        }
                        else
                        {
                            values[startIndex] = pivot;
                            values[middle] = pivot = first;
                        }
                    }
                    else if (comparer(pivot, last) > 0)
                    {
                        values[endIndex] = pivot;
                        if (comparer(first, last) > 0)
                        {
                            values[startIndex] = last;
                            values[middle] = pivot = first;
                        }
                        else
                        {
                            values[middle] = pivot = last;
                        }
                    }
                    // The two ends are already placed, so partitioning starts one step inside.
                    int lo = startIndex + 1, hi = endIndex - 1;
                    while (true)
                    {
                        while (comparer(values[lo], pivot) < 0) ++lo;
                        while (comparer(pivot, values[hi]) < 0) --hi;
                        if (lo >= hi)
                        {
                            if (lo == hi)
                            {
                                // Both cursors met on one element; step past it on both sides.
                                ++lo;
                                --hi;
                            }
                            break;
                        }
                        valueType moved = values[lo];
                        values[lo] = values[hi];
                        values[hi] = moved;
                        if (++lo > --hi) break;
                    }
                    if (hi - startIndex > endIndex - lo)
                    {
                        // Right part is smaller: recurse only if it starts inside or before the window end.
                        if (lo < endIndex && lo <= getEndIndex) sort(lo, endIndex);
                        // Left part ends before the window: nothing left to order.
                        if (hi < skipCount) break;
                        endIndex = hi;
                    }
                    else
                    {
                        // Left part is not larger: recurse only if it reaches the window start.
                        if (startIndex < hi && hi >= skipCount) sort(startIndex, hi);
                        // Right part begins after the window: nothing left to order.
                        if (lo > getEndIndex) break;
                        startIndex = lo;
                    }
                }
                while (startIndex < endIndex);
            }
        }
        /// <summary>
        /// Orders just enough of the array that the page starting at <paramref name="skipCount"/>
        /// and holding <paramref name="getCount"/> items is in its sorted position, then returns that page.
        /// </summary>
        /// <param name="values">Array to reorder in place; null yields null.</param>
        /// <param name="comparer">Comparison delegate; must not be null.</param>
        /// <param name="skipCount">Number of leading records to skip.</param>
        /// <param name="getCount">Number of records wanted; the effective window is computed by the range helper.</param>
        /// <returns>A collection built over the page window of <paramref name="values"/>, or an empty collection when the window is empty.</returns>
        public static showjim.sys.collection<valueType> rangeSort<valueType>(valueType[] values, Func<valueType, valueType, int> comparer, int skipCount, int getCount)
        {
            if (values == null) return null;
            if (comparer == null) throw showjim.sys.pub.nullException;
            showjim.sys.array.range window = new showjim.sys.array.range(values.Length, skipCount, getCount);
            getCount = window.getCount;
            if (getCount <= 0) return new showjim.sys.collection<valueType>();
            if (getCount > 1)
            {
                // A single-item page needs no ordering work here; larger pages get a partial sort.
                rangeSorter<valueType> worker = new rangeSorter<valueType>();
                worker.values = values;
                worker.comparer = comparer;
                worker.skipCount = window.skipCount;
                worker.getEndIndex = window.endIndex - 1;
                worker.sort(0, values.Length - 1);
            }
            return new collection<valueType>(values, window.skipCount, getCount);
        }
        /// <summary>
        /// Page-sorts a sub-range of the array: within values[startIndex, startIndex + count) it orders
        /// just enough that the page starting <paramref name="skipCount"/> items into the sub-range and
        /// holding <paramref name="getCount"/> items is in its sorted position, then returns that page.
        /// </summary>
        /// <param name="values">Array to reorder in place; null yields null.</param>
        /// <param name="startIndex">First index of the sub-range; negative values are clamped to 0 after the end has been computed.</param>
        /// <param name="count">Length of the sub-range counted from the original <paramref name="startIndex"/>; clipped to the array length.</param>
        /// <param name="comparer">Comparison delegate; must not be null.</param>
        /// <param name="skipCount">Number of records to skip, relative to the (clamped) sub-range start.</param>
        /// <param name="getCount">Number of records wanted.</param>
        /// <returns>A collection built over the page window of <paramref name="values"/>, or an empty collection when the window is empty.</returns>
        public static showjim.sys.collection<valueType> rangeSort<valueType>(valueType[] values, int startIndex, int count, Func<valueType, valueType, int> comparer, int skipCount, int getCount)
        {
            if (values == null) return null;
            if (comparer == null) throw showjim.sys.pub.nullException;
            int endIndex = startIndex + count;
            // Clamp the start like the ranged sort overload does. The previous "count < 0" check here
            // had no effect (count is recomputed below) and let a negative startIndex reach the sorter.
            if (startIndex < 0) startIndex = 0;
            if (endIndex > values.Length) endIndex = values.Length;
            if ((count = endIndex - startIndex) > 0)
            {
                // Translate the page window into absolute indexes and clip it to the sub-range.
                int getEndIndex = (skipCount += startIndex) + getCount;
                if (skipCount < startIndex) skipCount = startIndex;
                if (getEndIndex > endIndex) getEndIndex = endIndex;
                if ((getCount = getEndIndex - skipCount) > 0)
                {
                    if (getCount > 1)
                    {
                        // The sorter takes inclusive bounds, hence the pre-decrements.
                        new rangeSorter<valueType>
                        {
                            values = values,
                            comparer = comparer,
                            skipCount = skipCount,
                            getEndIndex = --getEndIndex
                        }.sort(startIndex, --endIndex);
                    }
                    return new collection<valueType>(values, skipCount, getCount);
                }
            }
            return new showjim.sys.collection<valueType>();
        }
        /// <summary>
        /// Record window: turns (total, skip, take) into a clamped [start, end) index pair.
        /// </summary>
        public struct range
        {
            /// <summary>
            /// Total number of records (never negative).
            /// </summary>
            private int Count;
            /// <summary>
            /// Start position (inclusive).
            /// </summary>
            private int StartIndex;
            /// <summary>
            /// Number of records skipped, i.e. the start position.
            /// </summary>
            public int skipCount
            {
                get { return StartIndex; }
            }
            /// <summary>
            /// End position (exclusive).
            /// </summary>
            private int EndIndex;
            /// <summary>
            /// End position (exclusive).
            /// </summary>
            public int endIndex
            {
                get { return EndIndex; }
            }
            /// <summary>
            /// Number of records in the window.
            /// </summary>
            public int getCount
            {
                get { return EndIndex - StartIndex; }
            }

            /// <summary>
            /// Builds the window, clamping it to [0, count].
            /// </summary>
            /// <param name="count">Total number of records; negative values are treated as 0.</param>
            /// <param name="skipCount">Number of records to skip; a negative value starts at 0 and shortens the window.</param>
            /// <param name="getCount">Number of records to take; a negative value takes everything after the skip, 0 yields an empty window.</param>
            public range(int count, int skipCount, int getCount)
            {
                Count = count < 0 ? 0 : count;
                if (skipCount < Count && getCount != 0)
                {
                    if (getCount > 0)
                    {
                        if (skipCount >= 0)
                        {
                            StartIndex = skipCount;
                            // Compare against the remaining room instead of adding first:
                            // skipCount + getCount can exceed int.MaxValue (e.g. getCount == int.MaxValue),
                            // which used to wrap to a negative end index and a negative getCount.
                            EndIndex = getCount > Count - skipCount ? Count : skipCount + getCount;
                        }
                        else
                        {
                            StartIndex = 0;
                            // skipCount is negative and getCount positive, so this sum cannot overflow.
                            if ((EndIndex = skipCount + getCount) > Count) EndIndex = Count;
                            else if (EndIndex < 0) EndIndex = 0;
                        }
                    }
                    else
                    {
                        StartIndex = skipCount >= 0 ? skipCount : 0;
                        EndIndex = Count;
                    }
                }
                else StartIndex = EndIndex = 0;
            }
        }

 

其实我这里要说的重点不是快速排序,而是借快速排序的实现说递归优化,不知你是否看明白了,因为递归也是很常用的。
快速排序在不同的场合还有不同的优化方法,比如对于较大的结构体可以创建一个用于排序的索引数组,比如对于排序字段计算复杂的可以创建一个用于排序的缓存值数组,比如数据段小于8的时候采用硬编码方式。

posted on 2012-05-11 14:43  肖进  阅读(1802)  评论(1)  编辑  收藏  举报