并行排序算法
Author:Eaglet
今天早晨看到 蛙蛙池塘 的这篇博客 谁能把这个程序的性能提升一倍?---并行排序算法 。促使我写了一个并行排序算法,这个排序算法充分利用多核CPU进行并行计算,从而提高排序的效率。
先简单说一下蛙蛙池塘 给的A,B,C 三种算法(见上面引用的那篇博客),A算法将耗时的平方和开平方计算放到比较函数中,导致Array.Sort 时,每次亮亮比较都要执行平方和开平方计算,其平均算法复杂度为 O(nlog2n) 。 而B 将平方和开平方计算提取出来,算法复杂度降低到 O(n) ,这也就是为什么B比A效率要高很多的缘故。C 和 B 相比,将平方函数替换成了 x*x ,由于少了远程函数调用和Pow函数本身的开销,效率有提高了不少。我在C的基础上编写了D算法,D算法采用并行计算技术,在我的双核笔记本电脑上数据量比较大的情况下,其排序效率较C要提高30%左右。
下面重点介绍这个并行排序算法。算法思路其实很简单,就是将要排序的数组按照处理器数量等分成若干段,然后用和处理器数量等同的线程并行对各个小段进行排序,排序结束和,再在单一线程中对这若干个已经排序的小段进行归并排序,最后输出完整的排序结果。考虑到和.Net 2.0 兼容,我没有用微软提供的并行库,而是用多线程来实现。
下面是测试结果:
n | A | B | C | D |
32768 | 0.7345 | 0.04122 | 0.0216 | 0.0254 |
65535 | 1.5464 | 0.08863 | 0.05139 | 0.05149 |
131072 | 3.2706 | 0.1858 | 0.118 | 0.108 |
262144 | 6.8423 | 0.4056 | 0.29586 | 0.21849 |
524288 | 15.0342 | 0.9689 | 0.7318 | 0.4906 |
1048576 | 31.6312 | 1.9978 | 1.4646 | 1.074 |
2097152 | 66.9134 | 4.1763 | 3.0828 | 2.3095 |
从测试结果上看,当要排序的数组长度较短时,并行排序的效率甚至还没有不进行并行排序高,这主要是多线程的开销造成的。当数组长度增大到25万以上时,并行排序的优势开始体现出来,随着数组长度的增长,排序时间最后基本稳定在但线程排序时间的 74% 左右,其中并行排序的消耗大概在50%左右,归并排序的消耗在 14%左右。由此也可以推断,如果在4CPU的机器上,其排序时间最多可以减少到单线程的 14 + 25 = 39%。8 CPU 为 14 + 12.5 = 26.5%
目前这个算法在归并算法上可能还有提高的余地,如果哪位高手能够进一步提高这个算法,不妨贴出来一起交流交流。
下面分别给出并行排序和归并排序的代码:
并行排序类 ParallelSort
Paralletsort 类是一个通用的泛型,调用起来非常简单,下面给一个简单的int型数组的排序示例:
{
IComparer
}
public void SortInt(int[] array)
{
Sort.ParallelSort<int> parallelSort = new Sort.ParallelSort<int>();
parallelSort.Sort(array, new IntComparer());
}
只要实现一个T类型两两比较的接口,然后调用ParallelSort 的 Sort 方法就可以了,是不是很简单?
下面是 ParallelSort类的代码
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
namespace Sort
{
/// <summary>
/// ParallelSort
/// </summary>
/// <typeparam name="T"></typeparam>
public class ParallelSort<T>
{
enum Status
{
Idle = 0,
Running = 1,
Finish = 2,
}
class ParallelEntity
{
public Status Status;
public T[] Array;
public IComparer<T> Comparer;
public ParallelEntity(Status status, T[] array, IComparer<T> comparer)
{
Status = status;
Array = array;
Comparer = comparer;
}
}
private void ThreadProc(Object stateInfo)
{
ParallelEntity pe = stateInfo as ParallelEntity;
lock (pe)
{
pe.Status = ParallelSort<T>.Status.Running;
Array.Sort(pe.Array, pe.Comparer);
pe.Status = ParallelSort<T>.Status.Finish;
}
}
public void Sort(T[] array, IComparer<T> comparer)
{
//Calculate process count
int processorCount = Environment.ProcessorCount;
//If array.Length too short, do not use Parallel sort
if (processorCount == 1 || array.Length < processorCount)
{
Array.Sort(array, comparer);
return;
}
//Split array
ParallelEntity[] partArray = new ParallelEntity[processorCount];
int remain = array.Length;
int partLen = array.Length / processorCount;
//Copy data to splited array
for (int i = 0; i < processorCount; i++)
{
if (i == processorCount - 1)
{
partArray[i] = new ParallelEntity(Status.Idle, new T[remain], comparer);
}
else
{
partArray[i] = new ParallelEntity(Status.Idle, new T[partLen], comparer);
remain -= partLen;
}
Array.Copy(array, i * partLen, partArray[i].Array, 0, partArray[i].Array.Length);
}
//Parallel sort
for (int i = 0; i < processorCount - 1; i++)
{
ThreadPool.QueueUserWorkItem(new WaitCallback(ThreadProc), partArray[i]);
}
ThreadProc(partArray[processorCount - 1]);
//Wait all threads finish
for (int i = 0; i < processorCount; i++)
{
while (true)
{
lock (partArray[i])
{
if (partArray[i].Status == ParallelSort<T>.Status.Finish)
{
break;
}
}
Thread.Sleep(0);
}
}
//Merge sort
MergeSort<T> mergeSort = new MergeSort<T>();
List<T[]> source = new List<T[]>(processorCount);
foreach (ParallelEntity pe in partArray)
{
source.Add(pe.Array);
}
mergeSort.Sort(array, source, comparer);
}
}
}
多路归并排序类 MergeSort
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Sort
{
/// <summary>
/// MergeSort
/// </summary>
/// <typeparam name="T"></typeparam>
public class MergeSort<T>
{
public void Sort(T[] destArray, List<T[]> source, IComparer<T> comparer)
{
//Merge Sort
int[] mergePoint = new int[source.Count];
for (int i = 0; i < source.Count; i++)
{
mergePoint[i] = 0;
}
int index = 0;
while (index < destArray.Length)
{
int min = -1;
for (int i = 0; i < source.Count; i++)
{
if (mergePoint[i] >= source[i].Length)
{
continue;
}
if (min < 0)
{
min = i;
}
else
{
if (comparer.Compare(source[i][mergePoint[i]], source[min][mergePoint[min]]) < 0)
{
min = i;
}
}
}
if (min < 0)
{
continue;
}
destArray[index++] = source[min][mergePoint[min]];
mergePoint[min]++;
}
}
}
}
主函数及测试代码 在蛙蛙池塘代码基础上修改
using System.Collections.Generic;
using System.Diagnostics;
namespace Vector4Test
{
public class Vector
{
public double W;
public double X;
public double Y;
public double Z;
public double T;
}
internal class VectorComparer : IComparer<Vector>
{
public int Compare(Vector c1, Vector c2)
{
if (c1 == null || c2 == null)
throw new ArgumentNullException("Both objects must not be null");
double x = Math.Sqrt(Math.Pow(c1.X, 2)
+ Math.Pow(c1.Y, 2)
+ Math.Pow(c1.Z, 2)
+ Math.Pow(c1.W, 2));
double y = Math.Sqrt(Math.Pow(c2.X, 2)
+ Math.Pow(c2.Y, 2)
+ Math.Pow(c2.Z, 2)
+ Math.Pow(c2.W, 2));
if (x > y)
return 1;
else if (x < y)
return -1;
else
return 0;
}
}
internal class VectorComparer2 : IComparer<Vector>
{
public int Compare(Vector c1, Vector c2)
{
if (c1 == null || c2 == null)
throw new ArgumentNullException("Both objects must not be null");
if (c1.T > c2.T)
return 1;
else if (c1.T < c2.T)
return -1;
else
return 0;
}
}
internal class Program
{
private static void Print(Vector[] vectors)
{
//foreach (Vector v in vectors)
//{
// Console.WriteLine(v.T);
//}
}
private static void Main(string[] args)
{
Vector[] vectors = GetVectors();
Console.WriteLine(string.Format("n = {0}", vectors.Length));
Stopwatch watch1 = new Stopwatch();
watch1.Start();
A(vectors);
watch1.Stop();
Console.WriteLine("A sort time: " + watch1.Elapsed);
Print(vectors);
vectors = GetVectors();
watch1.Reset();
watch1.Start();
B(vectors);
watch1.Stop();
Console.WriteLine("B sort time: " + watch1.Elapsed);
Print(vectors);
vectors = GetVectors();
watch1.Reset();
watch1.Start();
C(vectors);
watch1.Stop();
Console.WriteLine("C sort time: " + watch1.Elapsed);
Print(vectors);
vectors = GetVectors();
watch1.Reset();
watch1.Start();
D(vectors);
watch1.Stop();
Console.WriteLine("D sort time: " + watch1.Elapsed);
Print(vectors);
Console.ReadKey();
}
private static Vector[] GetVectors()
{
int n = 1 << 21;
Vector[] vectors = new Vector[n];
Random random = new Random();
for (int i = 0; i < n; i++)
{
vectors[i] = new Vector();
vectors[i].X = random.NextDouble();
vectors[i].Y = random.NextDouble();
vectors[i].Z = random.NextDouble();
vectors[i].W = random.NextDouble();
}
return vectors;
}
private static void A(Vector[] vectors)
{
Array.Sort(vectors, new VectorComparer());
}
private static void B(Vector[] vectors)
{
int n = vectors.Length;
for (int i = 0; i < n; i++)
{
Vector c1 = vectors[i];
c1.T = Math.Sqrt(Math.Pow(c1.X, 2)
+ Math.Pow(c1.Y, 2)
+ Math.Pow(c1.Z, 2)
+ Math.Pow(c1.W, 2));
}
Array.Sort(vectors, new VectorComparer2());
}
private static void C(Vector[] vectors)
{
int n = vectors.Length;
for (int i = 0; i < n; i++)
{
Vector c1 = vectors[i];
c1.T = Math.Sqrt(c1.X * c1.X
+ c1.Y * c1.Y
+ c1.Z * c1.Z
+ c1.W * c1.W);
}
Array.Sort(vectors, new VectorComparer2());
}
private static void D(Vector[] vectors)
{
int n = vectors.Length;
for (int i = 0; i < n; i++)
{
Vector c1 = vectors[i];
c1.T = Math.Sqrt(c1.X * c1.X
+ c1.Y * c1.Y
+ c1.Z * c1.Z
+ c1.W * c1.W);
}
Sort.ParallelSort<Vector> parallelSort = new Sort.ParallelSort<Vector>();
parallelSort.Sort(vectors, new VectorComparer2());
}
}
}