【并发编程】Java7 - ForkJoin,将大任务拆分成小任务

1. 简介

  Java7提供了可以将大任务拆分成小任务执行再合并结果的框架——Fork/Join。其中,将大任务拆分成足够执行的小任务并发执行的过程称为Fork,将这些小任务结果整合后形成最终的结果的过程称为Join。

  Fork/Join框架的具体体现为ForkJoinTask抽象类,该类继承了Future,运行在ForkJoinPool线程池中。该类有三个是实现类:RecursiveAction、RecursiveTask、CountedCompleter(Java8新增)。

  其中,RecursiveAction无返回值,RecursiveTask有返回值,CountedCompleter和RecursiceAction功能类似,但在子任务阻塞或耗时较长时做了增强,它支持重写指定方法返回结果,也可不返回结果(返回null),回具体请浏览下文详解。
  需要注意的是,并不是任务拆分的越小越好,要根据业务需求,拆分成足够可以一次执行完成的任务,例如:需要一次查询5W条数据,可以拆分成拆分10个子任务,每个任务查询5K条,10个任务并发执行。

2. RecursiveAction

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * RecursiveAction 单元测试
 *
 * @author CL
 */
@Slf4j
public class RecursiveActionTest {

    /**
     * 测试斐波那契数列
     */
    @Test
    public void testFibonacci() {
        int n = 10;

        // 创建线程池,默认线程数为CPU核心数
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        Fibonacci task = new Fibonacci(n);
        forkJoinPool.invoke(task);
        int result = task.getResult();

        log.info("第 {} 个斐波那契数为:{}", n, result);
    }

    /**
     * 计算斐波那契数列
     */
    @RequiredArgsConstructor
    private static class Fibonacci extends RecursiveAction {

        private final int n;
        @Getter
        private int result;

        @Override
        protected void compute() {
            if (n <= 1) {
                result = n;
                return;
            }
            Fibonacci f1 = new Fibonacci(n - 1);
            Fibonacci f2 = new Fibonacci(n - 2);
            ForkJoinTask.invokeAll(f1, f2);

            result = f1.getResult() + f2.getResult();
        }

    }

    /**
     * 测试一组数求和
     */
    @Test
    public void testSum() {
        // 构造测试数据
        int total = 5_0000;
        Random random = new Random();
        List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());

        StopWatch stopWatch = new StopWatch();
        stopWatch.start("普通计算");

        int result1 = 0;
        for (Integer n : list) {
            result1 += n;
        }

        stopWatch.stop();

        stopWatch.start("Lambda计算");

        int result2 = list.stream().mapToInt(n -> n).sum();

        stopWatch.stop();

        stopWatch.start("ForkJoin计算");

        ForkJoinPool forkJoinPool = new ForkJoinPool();
        Sum task = new Sum(list);
        forkJoinPool.invoke(task);
        int result3 = task.getResult();

        stopWatch.stop();

        stopWatch.start("ForkJoin非阻塞方式计算");

        ForkJoinPool forkJoinPool2 = new ForkJoinPool();
        Sum2 task2 = new Sum2(list, null);
        forkJoinPool2.invoke(task2);
        int result4 = task2.getResult();

        stopWatch.stop();

        Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误");

        for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
            log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
        }
    }

    /**
     * 一组数求和
     */
    @RequiredArgsConstructor
    private static class Sum extends RecursiveAction {

        private final static int THRESHOLD = 1000;
        private final List<Integer> list;
        @Getter
        private int result;

        @Override
        protected void compute() {
            int total = list.size();
            if (total <= THRESHOLD) {
                result = list.stream().mapToInt(n -> n).sum();
                return;
            }
            int middle = total / 2;
            Sum s1 = new Sum(list.subList(0, middle));
            Sum s2 = new Sum(list.subList(middle, total));
            ForkJoinTask.invokeAll(s1, s2);

            result = s1.getResult() + s2.getResult();
        }

    }

    /**
     * 一组数求和<br/>
     * 本级任务不阻塞等待子任务执行,而是循环获取未执行的任务执行
     */
    @RequiredArgsConstructor
    private static class Sum2 extends RecursiveAction {

        private final static int THRESHOLD = 1000;
        @Getter
        private final List<Integer> list;
        @Getter
        private final Sum2 subTask;
        @Getter
        private int result;

        @Override
        protected void compute() {
            int total = list.size();
            int start = 0;
            Sum2 tempTask = null;
            // 拆分任务
            while (total > THRESHOLD) {
                tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask);
                tempTask.fork();

                start = start + THRESHOLD;
                total = total - THRESHOLD;
            }
            // 剩余最后一批任务
            int sum = list.subList(start, list.size()).stream().mapToInt(n -> n).sum();
            // 收集拆分任务结果
            while (tempTask != null) {
                if (tempTask.tryUnfork()) {
                    sum += tempTask.getList().stream().mapToInt(n -> n).sum();
                } else {
                    tempTask.join();
                    sum += tempTask.getResult();
                }
                tempTask = tempTask.getSubTask();
            }

            result = sum;
        }

    }

}

  斐波那契数列测试结果:

20:16:22.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55

  一组数求和测试结果:

20:28:56.607 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 普通计算 耗时:17 ms
20:28:56.614 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - Lambda计算 耗时:10 ms
20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin计算 耗时:13 ms
20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin非阻塞方式计算 耗时:10 ms

3. RecursiveTask

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * RecursiveTask 单元测试
 *
 * @author CL
 */
@Slf4j
public class RecursiveTaskTest {

    /**
     * 测试斐波那契数列
     */
    @Test
    public void testFibonacci() {
        int n = 10;

        // 创建线程池,默认线程数为CPU核心数
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        Fibonacci task = new Fibonacci(n);
        int result = forkJoinPool.invoke(task);

        log.info("第 {} 个斐波那契数为:{}", n, result);
    }

    /**
     * 计算斐波那契数列
     */
    @RequiredArgsConstructor
    private static class Fibonacci extends RecursiveTask<Integer> {

        private final int n;

        @Override
        protected Integer compute() {
            if (n <= 1) {
                return n;
            }
            Fibonacci f1 = new Fibonacci(n - 1);
            f1.fork();
            Fibonacci f2 = new Fibonacci(n - 2);

            return f2.compute() + f1.join();
        }

    }

    /**
     * 测试一组数求和
     */
    @Test
    public void testSum() {
        // 构造测试数据
        int total = 5_0000;
        Random random = new Random();
        List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());

        StopWatch stopWatch = new StopWatch();
        stopWatch.start("普通计算");

        int result1 = 0;
        for (Integer n : list) {
            result1 += n;
        }

        stopWatch.stop();

        stopWatch.start("Lambda计算");

        int result2 = list.stream().mapToInt(n -> n).sum();

        stopWatch.stop();

        stopWatch.start("ForkJoin计算");

        ForkJoinPool forkJoinPool = new ForkJoinPool();
        Sum task = new Sum(list);
        int result3 = forkJoinPool.invoke(task);

        stopWatch.stop();

        stopWatch.start("ForkJoin非阻塞方式计算");

        ForkJoinPool forkJoinPool2 = new ForkJoinPool();
        Sum2 task2 = new Sum2(list, null);
        int result4 = forkJoinPool2.invoke(task2);

        stopWatch.stop();

        Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误");

        for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
            log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
        }
    }

    /**
     * 一组数求和
     */
    @RequiredArgsConstructor
    private static class Sum extends RecursiveTask<Integer> {

        private final static int THRESHOLD = 1000;
        private final List<Integer> list;

        @Override
        protected Integer compute() {
            int total = list.size();
            if (total <= THRESHOLD) {
                return list.stream().mapToInt(n -> n).sum();
            }
            int middle = total / 2;
            Sum s1 = new Sum(list.subList(0, middle));
            s1.fork();
            Sum s2 = new Sum(list.subList(middle, total));

            return s2.compute() + s1.join();
        }

    }

    /**
     * 一组数求和<br/>
     * 本级任务不阻塞等待子任务执行,而是循环获取未执行的任务执行
     */
    @RequiredArgsConstructor
    private static class Sum2 extends RecursiveTask<Integer> {

        private final static int THRESHOLD = 1000;
        @Getter
        private final List<Integer> list;
        @Getter
        private final Sum2 subTask;

        @Override
        protected Integer compute() {
            int total = list.size();
            int start = 0;
            Sum2 tempTask = null;
            // 拆分任务
            while (total > THRESHOLD) {
                tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask);
                tempTask.fork();

                start = start + THRESHOLD;
                total = total - THRESHOLD;
            }
            // 剩余最后一批任务
            int sum = list.subList(start, list.size()).stream().mapToInt(n -> n).sum();
            // 收集拆分任务结果
            while (tempTask != null) {
                if (tempTask.tryUnfork()) {
                    sum += tempTask.getList().stream().mapToInt(n -> n).sum();
                } else {
                    sum += tempTask.join();
                }
                tempTask = tempTask.getSubTask();
            }

            return sum;
        }

    }

}

  斐波那契数列测试结果:

21:54:32.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55

  一组数求和测试结果:

21:54:33.092 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - 普通计算 耗时:7 ms
21:54:33.096 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - Lambda计算 耗时:6 ms
21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin计算 耗时:9 ms
21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin非阻塞方式计算 耗时:6 ms

4. CountedCompleter

  CountedCompleter是Java8推出的ForkJoinTask实现类,相比于Java7退出的以上两个实现类来说,它在子任务耗时过长或阻塞时更加健壮,但是理解起来比较困难。
  该类有两个成员变量:completer保存当前任务的父任务,如果为null,则表示顶级任务;pending保存当前任务的子任务数,每当子任务执行完成则pending会减1,当pending==0时,表示当前任务执行完成。

public abstract class CountedCompleter<T> extends ForkJoinTask<T> {
    private static final long serialVersionUID = 5232453752276485070L;

    /** This task's completer, or null if none */
    final CountedCompleter<?> completer;
    /** The number of pending tasks until completion */
    volatile int pending;
    

    ......
}

  构造方法:

方法及参数 描述
CountedCompleter() 无父级任务,初始化子任务数为0
CountedCompleter(CountedCompleter<?> completer) 指定父级任务,初始化子任务数为0
CountedCompleter(CountedCompleter<?> completer, int initialPendingCount) 指定父级任务,初始化子任务数

  常用方法:

方法及参数 描述
abstract void compute() 执行的计算任务,必须实现
void onCompletion(CountedCompleter<?> caller) 调用tryComplete后,如果当前任务的所有子任务都完成(当前任务完成),则会调用该方法处理完成后的业务
boolean onExceptionalCompletion(Throwable ex, CountedCompleter<?> caller) 是否向父级任务传递异常,默认传递。即当前任务在执行compute()方法时抛出异常,或者显式的调用completeExceptionally(Throwable ex)抛出异常时,为true则父任务也异常完成,为false不会传递到父级任务
CountedCompleter<?> getCompleter() 返回当前任务的父级任务,没有返回null
int getPendingCount() 获取子任务数
void setPendingCount(int count) 添加子任务数(非原子性)
void addToPendingCount(int delta) 添加子任务数(原子性)
CountedCompleter<?> getRoot() 返回当前任务的父级任务,没有返回自身
void tryComplete() 尝试完成任务。如果当前任务的所有子任务全部完成(pending==0),调用onCompletion(CountedCompleter<?> caller)方法执行完成后的逻辑,并将当前任务的线程状态设置为NORMAL,通知被当前任务阻塞的任务执行,否则通过自旋tr将pending减1。
void propagateCompletion() 尝试完成任务。和tryComplete()的区别是:当所有子任务执行完成后,不会调用onCompletion(CountedCompleter<?> caller)方法。
void complete(T rawResult) 无论当前任务是否完成,直接设置结果为指定 的结果,调用onCompletion(CountedCompleter<?> caller)方法,将当前任务的线程的状态设置为NORMAL。如果存在父级任务,并调用tryComplete()尝试结束子任务。
CountedCompleter<?> firstComplete() 如果当前任务已经完成,则返回当前任务;否则将pending减1,并返回null。通常使用在只需要获取任一任务结果时使用。
CountedCompleter<?> nextComplete() 如果当前任务存在父级任务,则调用父级任务的firstComplete()方法,否则将当前任务的线程的状态设置为NORMAL,并返回null。
void quietlyCompleteRoot() 相当于getRoot().quietlyComplete()
void helpComplete(int maxTasks) 当前任务阻塞等待任务执行时,尝试帮助当前任务执行其他未处理的任务。
void internalPropagateException(Throwable ex) 如果当前任务执行异常,并需要想父级任务传递时,循环传递异常只到最顶级任务
boolean exec() 执行compute()方法,并返回false。ForkJoinPool调用该方法执行任务,如果返回true,ForkJoinPool会将当前任务标记为完成,并通知被当前任务阻塞的其他线程,这也是和RecursiveAction、RecursiveTask区别所在
T getRawResult() 获取结果。方法内部为空,需要重写
void setRawResult(T t) 设置结果。方法内部为空,需要重写
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;

import java.util.List;
import java.util.Random;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * CountedCompleter 单元测试
 *
 * @author CL
 */
@Slf4j
public class CountedCompleterTest {

    /**
     * 测试一组数求和
     */
    @Test
    public void testSum() {
        // 构造测试数据
        int total = 5_0000;
        Random random = new Random();
        List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());

        StopWatch stopWatch = new StopWatch();
        stopWatch.start("普通计算");

        int result1 = 0;
        for (Integer n : list) {
            result1 += n;
        }

        stopWatch.stop();

        stopWatch.start("Lambda计算");

        int result2 = list.stream().mapToInt(n -> n).sum();

        stopWatch.stop();

        stopWatch.start("ForkJoin计算");

        Sum task = new Sum(null, list, new AtomicReference<>(0));
        int result3 = ForkJoinPool.commonPool().invoke(task);

        stopWatch.stop();

        Assert.isTrue(result1 == result2 && result1 == result3, "计算结果错误");

        for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
            log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
        }
    }

    /**
     * 一组数求和
     */
    private static class Sum extends CountedCompleter<Integer> {

        private final static int THRESHOLD = 1000;
        private List<Integer> list;
        private AtomicReference<Integer> result;

        public Sum(CountedCompleter<Integer> parentTask, List<Integer> list, AtomicReference<Integer> result) {
            super(parentTask);
            this.list = list;
            this.result = result;
        }

        @Override
        public void compute() {
            int total = list.size();
            if (total <= THRESHOLD) {
                result.getAndAccumulate(list.stream().mapToInt(n -> n).sum(), (a, b) -> a + b);

                // 不需要onCompletion(CountedCompleter)方法处理时,可以使用
                propagateCompletion();

                // 期望每个任务完成后继续执行onCompletion(CountedCompleter)方法时,可以使用
//                tryComplete();

                return;
            }
            int middle = total - THRESHOLD;
            List<Integer> subList = list.subList(middle, total);
            list = list.subList(0, middle);
            addToPendingCount(1);
            Sum s1 = new Sum(this, subList, result);
            s1.fork();

            // 继续执行
            this.exec();
        }

        @Override
        public Integer getRawResult() {
            return result.get();
        }

    }

}

  一组数求和测试结果:

22:54:17.123 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - 普通计算 耗时:9 ms
22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - Lambda计算 耗时:8 ms
22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - ForkJoin计算 耗时:9 ms
posted @ 2023-04-07 22:01  C3Stones  阅读(294)  评论(0编辑  收藏  举报