编写并发程序 Inversion
做完了 scala parallel 课程作业后,觉得 scala 写并发程序的便捷性是 java 永远都追不上的。scala 的Future 和 Promise,java 里 Future 和 CompleteFuture 实现了类似的功能,但是使用的便捷性还差的很远,java.util.Future 本身 API 较少,不支持链式操作。CompleteFuture 丰富了 Future 的 API,但是也不好用。
这里用 Scala parallel 学到的东西计算 Inversion. Inversion 叫做逆序对,它的 nlogn 算法的思想是在 merge sort 的 Merge 阶段计算逆序对的个数。先列出单线程解法。
public static long sort(List<Integer> numbers, int left, int right) {
if(left >= right) return 0L;
if(left + 1 == right) return 0L;
int mid = (right - left) / 2 + left;
long leftInversion = sort(numbers, left, mid);
long rightInversion = sort(numbers, mid, right);
long mergeInversion = merge(numbers, left, mid, right);
return (leftInversion + rightInversion + mergeInversion);
}
public static long merge(List<Integer> numbers, int left, int mid, int right) {
List<Integer> buf = new ArrayList<>();
int leftCursor = left, rightCursor = mid;
long inversion = 0;
while(leftCursor < mid && rightCursor < right) {
if(numbers.get(leftCursor) <= numbers.get(rightCursor)) buf.add(numbers.get(leftCursor ++));
else {
buf.add(numbers.get(rightCursor ++));
inversion += (mid - leftCursor);
}
while(leftCursor < mid) buf.add(numbers.get(leftCursor ++));
while(rightCursor < right) buf.add(numbers.get(rightCursor ++));
for(int i = 0; i < (right - left); i ++) numbers.set(i+left, buf.get(i));
return inversion;
}
}
做 benchmark 一定要注意同一段程序要 run 多遍,以最后一遍的运行时间为准,因为预热阶段包括对内存的填充,线程的创建等等。
在我的 4 核 i7 mac 上跑了三轮,10万数字的 inversion, 时间分别是 100ms, 70ms, 40ms.
然后是并行解法。并行解法使用了 ForkJoinPool,别的 threadPool 也是一样的,但是性能上是否有区别就不知道了。
为了避免每次执行任务都要创建 ForkJoinTask, 先写一个 wrapper.
public abstract class TaskScheduler {
public abstract <T> ForkJoinTask<T> schedule(Function<Void, T> func);
}
public class DefaultTaskScheduler extends TaskScheduler {
public <T> ForkJoinTask<T> schedule(Function<Void, T> func) {
ForkJoinTask<T> task = new ForkJoinTask<T>() {
protected T compute() { return func.apply(null); }
};
ForkJoinCom.pool.execute(task);
return task;
}
}
有了这个 Wrapper 以后,就可以通过 schedule 函数直接把运算逻辑变成 ForkJoinTask。
merge 是顺序执行的,写不出它的并行实现,但是 sort 函数是分而治之算法,每次把 List 划分为不相交的两段,可以并行的对这两段排序。
public static long parSort(List<Integer> nums, int left, int right, int threshold) {
if(right - left <= threshold) return Inversion.sort(nums, left, right);
int mid = (right - left) /2 + left;
ForkJoinTask<Long> leftTask = ForkJoinCom.scheduler.schedule(Void -> parSort(nums, left, mid, threshold));
ForkJoinTask<Long> rightTask = ForkJoinCom.scheduler.schedule(Void -> parSort(nums, mid, right, threshold));
long leftInversions = leftTask.join();
long rightInversions = rightTask.join();
long mergeInversions = Inversion.merge(numbers, left, mid, right);
return leftInversions + rightInversions + mergeInversions;
}
到这里,并行解法就算写完了,但是性能提升的并不明显。尝试调整 threshold, 调整 ForkJoinPool 的线程数目,效果依然不明显。回忆 scala 作业题里老师给出的实现,突然想到,当 leftTask, rightTask 正在执行的时候,当前线程只是傻等着,什么都没干,这是对 CPU 资源的浪费。照着这个思路稍微修改了下 parSort 方法: