JDK-In-Action-ForkJoin-Get-Start

ForkJoin框架概述

一个特殊的执行程序服务ExecutorService实现.
ForkJoin框架适用于执行计算密集型任务,通过再每个CPU核上使用一个线程来加速运算.
这些任务应该可以从一个大任务分解成多个小的子任务,分别计算后再汇总结果.

框架使用一种称为工作窃取(work stealing)的方法来平衡可用线程的工作负载.
每个线程都包含一个自己的线程安全的双端队列(deque),自身线程从队列的头获取任务,当自身队列可用任务为空后,从其他线程的队列的尾部偷取任务.

JDK8的并行流API,底层就是使用的ForkJoinPool.

API 概览

  • ForkJoinPool
    创建一个并行线程池,提交任务的入口.可选三个指定的参数(并行度,线程工厂,非受检异常处理器,异步模式).

默认提供一个公共的ForkJoinPool.commonPool(),这个默认构造的池确保系统退出前中断执行中的任务.

  • RecursiveTask
    用于定义有返回值的任务
  • RecursiveAction
    用于无返回值的任务

示例

统计一个数组中满足条件的元素个数


import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 * 统计一个数组中满足条件的元素个数
 */
public class FilterCounterExample {

    public static void main(String[] args) {

        final int size = 1_000_000;
        double[] arr = new double[size];
        for (int i = 0; i < size; i++) {
            arr[i] = Math.random();
        }
        //JDK7的解法
        JDK7(arr);
        //JDK8的解法
        JDK8(arr);
    }

    private static void JDK8(double[] arr) {
        //使用并行流,底层就是运行的FORK-JOIN
        final long count = Arrays.stream(arr).parallel().filter(d -> d > 0.5).count();
        System.out.println(count);
    }

    private static void JDK7(double[] arr) {
        Counter counter = new Counter(arr, 0, arr.length, new Filter() {
            @Override
            public boolean accept(double e) {
                return e > 0.5;
            }
        });

        ForkJoinPool pool = ForkJoinPool.commonPool();
        pool.invoke(counter);
        final Integer count = counter.join();
        System.out.println(count);
    }


}

interface Filter {
    boolean accept(double e);
}

class Counter extends RecursiveTask<Integer> {

    public static final int THRESHOLD = 10000;
    private double[] arr;
    private int from;
    private int to;
    private Filter filter;

    public Counter(double[] arr, int from, int to, Filter filter) {
        this.arr = arr;
        this.from = from;
        this.to = to;
        this.filter = filter;
    }

    @Override
    protected Integer compute() {
        if (from - to < THRESHOLD) {
            int count = 0;
            for (int i = from; i < to; i++) {
                if (filter.accept(arr[i])) {
                    count++;
                }
            }
            return count;
        } else {
            int mid = (from + to) / 2;
            Counter one = new Counter(arr, from, mid, filter);
            Counter two = new Counter(arr, mid, to, filter);
            invokeAll(one, two);
            return one.join() + two.join();
        }
    }
}

并行将数组中满足条件的元素值倍增


import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.stream.IntStream;

/**
 * 并行将数组中满足条件的元素个数倍增
 */
public class DoubleValueExample {

    public static void main(String[] args) {

        final int size = 1_000_000;
        double[] arr = new double[size];
        for (int i = 0; i < size; i++) {
            arr[i] = Math.random();
        }
        //JDK7的解法
        JDK7(arr);
        //JDK8的解法
        JDK8(arr);
    }

    private static void JDK8(double[] arr) {
        //使用并行流,底层就是运行的FORK-JOIN
        double before = arr[arr.length / 2];
        IntStream.range(0, arr.length).parallel().forEach(i -> arr[i] *= 2);
        double after = arr[arr.length / 2];
        System.out.println(before + "," + after);
    }

    private static void JDK7(double[] arr) {
        double before = arr[arr.length / 2];
        DoubleValue task = new DoubleValue(arr, 0, arr.length);

        ForkJoinPool pool = ForkJoinPool.commonPool();
        ForkJoinTask<Void> future = pool.submit(task);
        future.join();
        double after = arr[arr.length / 2];
        System.out.println(before + "," + after);
    }


}

class DoubleValue extends RecursiveAction {

    public static final int THRESHOLD = 10000;
    private double[] arr;
    private int from;
    private int to;

    public DoubleValue(double[] arr, int from, int to) {
        this.arr = arr;
        this.from = from;
        this.to = to;
    }

    @Override
    protected void compute() {
        if (from - to < THRESHOLD) {
            int count = 0;
            for (int i = from; i < to; i++) {
                arr[i] *= 2;
            }
        } else {
            int mid = (from + to) / 2;
            DoubleValue one = new DoubleValue(arr, from, mid);
            DoubleValue two = new DoubleValue(arr, mid, to);
            invokeAll(one, two);
        }
    }
}

引用

posted @ 2020-05-10 21:33  onion94  阅读(196)  评论(0编辑  收藏  举报