利用多线程对数组进行归并排序
多线程处理归并排序的方法一般为:
假设有n个线程同步处理,就将数组等分成n份,每个线程处理一份,再对最后n个有序数组进行归并。
为了使对整个算法具有可扩展性,即线程数n可以自定义,笔者将线程类、处理数组类等进行封装,分为最主要的4个类:Array, Merge, MyThread, MoreThreads
,代码如下:
/*Array.java*/
import java.util.ArrayList;
/**
* @author duyue
*
* 这个类用来处理数组
*
* 原理:
* 创建待排序数组成功后,需要配合多线程(假设有n个线程)分别排序,
* 需要将数组尽量等分成n个分数组(保存到列表中),由n个线程分别归
* 并排序,并将各个有序数组(再次保存到列表中),最后整合(不归并
* 整合)并覆盖原数组,等待最后归并。
*/
class Array {
/**
* 构造一个保存数组的列表,用于保存分割后的分数组
*/
static ArrayList<int[]> arrayList = new ArrayList<int[]>();
/**
* @param length 数组长度
* @return 待排序的数组
*/
static int[] createArray(int length) {
int[] array = new int[length];
for (int i = 0; i < length; i++) {
array[i] = (int) (Math.random() * 10000);
}
return array;
}
/**
* @param array 待分割(多线程排序需要)的数组
* @param num 线程数,即要分割的份数
*/
static void divideArray(int[] array, int num) {
int k = 0; //记录原数组的复制进度,k代表当前数组的复制初始点
for (int i = 0; i < num; i++) {
int point = array.length / num; //分数组的长度
int[] a = new int[0]; //保存分数组
/*考虑到不够整除的情况,将剩余的项全部放在最后一个分数组中*/
if (i != num - 1) a = new int[point];
if (i == num - 1) a = new int[array.length - k];
/*将array[k, k + a.length -1]复制到a[0, a.length]中*/
System.arraycopy(array, k, a, 0, a.length);
arrayList.add(a); //把得到的分数组保存在列表中
k += point; //移动复制初始点
}
}
/**
* @param newArray 由有序分数组整合(不归并)的新数组
* @param num 有序分数组的个数,即由num个线程分别排序后得到的数组数,也就是线程数
*/
static void newArray(int[] newArray, int num) {
/*原理与divideArray方法相似*/
int k = 0; //记录新数组的待整合初始点
/*把列表元素(即数组)逐个复制到新的数组中*/
for (int i = 0; i < num; i++) {
System.arraycopy(arrayList.get(i), 0, newArray, k, arrayList.get(i).length);
k += arrayList.get(i).length; //移动待整合初始点
}
}
}
/*Merge.java*/
/**
* @author duyue
*
* 这是对数组进行归并排序的类
*/
class Merge {
private int[] temp; //暂时存放待排序数组的temp数组
/**
* @param a 待排序的数组由构造器传递到类中
*/
Merge(int[] a) {
temp = new int[a.length];
}
public void sort(int[] a) {
sort(a, 0, a.length - 1);
}
public void sort(int[] a, int low, int high) {
if (low >= high)
return;
int mid = low + (high - low) / 2;
sort(a, low, mid); //将左半边排序
sort(a, mid + 1, high); //将左半边排序
merge(a, low, mid, high); //归并结果
}
/**
* @param a 待归并的数组,其中a[low,mid]和a[mid+1,high]都是有序的
*/
public void merge(int[] a, int low, int mid, int high) {
int i = low, j = mid + 1;
/*将a[low,high]复制到temp[low,high]*/
System.arraycopy(a, low, temp, low, high - low + 1);
/*归并到a[low,high]*/
for (int k = low; k <= high; k++) {
if (i > mid)
a[k] = temp[j++];
else if (j > high)
a[k] = temp[i++];
else if (temp[j] < temp[i])
a[k] = temp[j++];
else
a[k] = temp[i++];
}
}
}
/*MyThread.java*/
import java.util.concurrent.CountDownLatch;
/**
* @author duyue
*
* 这个类用来定义线程,使其能够对数组进行归并排序处理
*/
class MyThread extends Thread {
public int[] aux; //定义一个数组,用来保存待处理的数组
private CountDownLatch latch; //定义这个类用来等待各个线程都完成工作,再进行下一步操作
/*通过构造器将待处理的数组传递到线程的类中*/
public MyThread(int[] aux, CountDownLatch latch) {
this.aux = aux;
this.latch = latch;
}
public void run() {
Merge mergeThread = new Merge(aux);
mergeThread.sort(aux);
latch.countDown();
}
}
/*MoreThreads.java*/
import java.util.ArrayList;
import java.util.concurrent.CountDownLatch;
/**
* @author duyue
*
* 本类是多线程处理归并排序的核心部分。
*
* 原理:
* 由用户指定线程数,例如n个线程,将数组分为n份,分别用n个线程对这n个数组进行归并排序,
* 得到n个有序分数组,再对这n个有序数组归并就得出最后的结果。
* 线程数越多,相应的速度就会越快。
* 要处理的数组长度越长,多线程与单线程的对比就越大。
*/
class MoreThreads {
/**
* @param num 线程数,由用户定义
*/
MoreThreads(int num) {
System.out.println("现在是" + num + "个线程处理归并排序:");
int length = 100; //数组总长度
for (int j = 0; j < 6; j++) {
/*记录起始时间*/
long beginTime = System.currentTimeMillis();
/*创建待排序的数组*/
int[] myArray = Array.createArray(length);
/*将数组近乎等分成num份,以便利用多线程对各个数组排序*/
Array.divideArray(myArray, num);
/*
* 对各个数组利用num个线程同步排序。
* 将num个线程保存在列表threads中,方便将各个线程处理后的数组调出。
* CountDownLatch类用于等待所有的线程都工作完成后,进行最终的归并。
*/
ArrayList<MyThread> threads = new ArrayList<MyThread>();
CountDownLatch latch = new CountDownLatch(num);
for (int i = 0; i < num; i++) {
MyThread thread = new MyThread(Array.arrayList.get(i), latch);
thread.start();
threads.add(thread);
}
try {
latch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
/*
* 清空原装有数组列表中的元素,
* 将排序后的各个数组从threads列表中调出,添加到数组列表Array中
*/
Array.arrayList.clear();
for (int i = 0; i < num; i++) {
Array.arrayList.add(threads.get(i).aux);
}
/*把各个排序后数组规整到长数组中,并对三个有序数组归并到一个数组中*/
Array.newArray(myArray, num);
/*
* 对有序数组进行归并
* 归并原理:
* 将第一个有序分数组(即第一个线程排序后的数组)与其下一个有序分数组(即第二个线程
* 排序后的数组)归并成一个数组,再把归并的数组与其下一个有序分数组(即第三个线程排
* 序后的数组)归并,依此类推.
*/
int low = 0; //整合后的长数组myArray的首位
int mid = -1; //待归并的第一个有序分数组的末位
int high; //待归并的第二个有序分数组的末位
for (int i = 0; i < num - 1; i++) {
Merge merge = new Merge(myArray);
mid = Array.arrayList.get(i).length + mid;
high = mid + Array.arrayList.get(i + 1).length;
merge.merge(myArray, low, mid, high);
}
/*记录结束时间*/
long endTime = System.currentTimeMillis();
System.out.println(length + "项数组归并排序的时间:" + (endTime - beginTime) + "ms");
length = length * 10;
Array.arrayList.clear(); //清空列表内容,对下一次循环不造成影响
}
}
}
运行以下代码即可测试:
/*TestThread*/
import java.util.Scanner;
/**
* @author duyue
*
* 这是一个测试类,用于展示结果。
*/
public class TestThread {
public static void main(String[] args) {
new MoreThreads(1);
System.out.println("--------------------------------");
new MoreThreads(2);
System.out.println("--------------------------------");
new MoreThreads(3);
System.out.println("--------------------------------");
System.out.println("你还想尝试更多线程处理归并排序吗?(y:yes, n:no)");
while (true) {
Scanner in = new Scanner(System.in);
String s = in.nextLine();
if (s.equals("n")) {
System.out.println("byebye!");
in.close();
break;
} else if (s.equals("y")) {
System.out.println("请输入要尝试的线程数:");
new MoreThreads(in.nextInt());
System.out.println("--------------------------------");
System.out.println("你还想尝试更多线程处理归并排序吗?(y:yes, n:no)");
} else
System.out.println("输入错误!请重新输入");
}
}
}