【并发编程】Java7 - ForkJoin,将大任务拆分成小任务

1. 简介#

  Java7提供了可以将大任务拆分成小任务执行再合并结果的框架——Fork/Join。其中,将大任务拆分成足够执行的小任务并发执行的过程称为Fork,将这些小任务结果整合后形成最终的结果的过程称为Join。

  Fork/Join框架的具体体现为ForkJoinTask抽象类,该类继承了Future,运行在ForkJoinPool线程池中。该类有三个是实现类:RecursiveAction、RecursiveTask、CountedCompleter(Java8新增)。

  其中,RecursiveAction无返回值,RecursiveTask有返回值,CountedCompleter和RecursiceAction功能类似,但在子任务阻塞或耗时较长时做了增强,它支持重写指定方法返回结果,也可不返回结果(返回null),回具体请浏览下文详解。
  需要注意的是,并不是任务拆分的越小越好,要根据业务需求,拆分成足够可以一次执行完成的任务,例如:需要一次查询5W条数据,可以拆分成拆分10个子任务,每个任务查询5K条,10个任务并发执行。

2. RecursiveAction#

Copy
import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.springframework.util.Assert; import org.springframework.util.StopWatch; import java.util.List; import java.util.Random; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinTask; import java.util.concurrent.RecursiveAction; import java.util.stream.Collectors; import java.util.stream.IntStream; /** * RecursiveAction 单元测试 * * @author CL */ @Slf4j public class RecursiveActionTest { /** * 测试斐波那契数列 */ @Test public void testFibonacci() { int n = 10; // 创建线程池,默认线程数为CPU核心数 ForkJoinPool forkJoinPool = new ForkJoinPool(); Fibonacci task = new Fibonacci(n); forkJoinPool.invoke(task); int result = task.getResult(); log.info("第 {} 个斐波那契数为:{}", n, result); } /** * 计算斐波那契数列 */ @RequiredArgsConstructor private static class Fibonacci extends RecursiveAction { private final int n; @Getter private int result; @Override protected void compute() { if (n <= 1) { result = n; return; } Fibonacci f1 = new Fibonacci(n - 1); Fibonacci f2 = new Fibonacci(n - 2); ForkJoinTask.invokeAll(f1, f2); result = f1.getResult() + f2.getResult(); } } /** * 测试一组数求和 */ @Test public void testSum() { // 构造测试数据 int total = 5_0000; Random random = new Random(); List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList()); StopWatch stopWatch = new StopWatch(); stopWatch.start("普通计算"); int result1 = 0; for (Integer n : list) { result1 += n; } stopWatch.stop(); stopWatch.start("Lambda计算"); int result2 = list.stream().mapToInt(n -> n).sum(); stopWatch.stop(); stopWatch.start("ForkJoin计算"); ForkJoinPool forkJoinPool = new ForkJoinPool(); Sum task = new Sum(list); forkJoinPool.invoke(task); int result3 = task.getResult(); stopWatch.stop(); stopWatch.start("ForkJoin非阻塞方式计算"); ForkJoinPool forkJoinPool2 = new ForkJoinPool(); Sum2 task2 = new Sum2(list, null); forkJoinPool2.invoke(task2); int result4 = task2.getResult(); stopWatch.stop(); Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误"); for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) { log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis()); } } /** * 一组数求和 */ @RequiredArgsConstructor private static class Sum extends RecursiveAction { private final static int THRESHOLD = 1000; private final List<Integer> list; @Getter private int result; @Override protected void compute() { int total = list.size(); if (total <= THRESHOLD) { result = list.stream().mapToInt(n -> n).sum(); return; } int middle = total / 2; Sum s1 = new Sum(list.subList(0, middle)); Sum s2 = new Sum(list.subList(middle, total)); ForkJoinTask.invokeAll(s1, s2); result = s1.getResult() + s2.getResult(); } } /** * 一组数求和<br/> * 本级任务不阻塞等待子任务执行,而是循环获取未执行的任务执行 */ @RequiredArgsConstructor private static class Sum2 extends RecursiveAction { private final static int THRESHOLD = 1000; @Getter private final List<Integer> list; @Getter private final Sum2 subTask; @Getter private int result; @Override protected void compute() { int total = list.size(); int start = 0; Sum2 tempTask = null; // 拆分任务 while (total > THRESHOLD) { tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask); tempTask.fork(); start = start + THRESHOLD; total = total - THRESHOLD; } // 剩余最后一批任务 int sum = list.subList(start, list.size()).stream().mapToInt(n -> n).sum(); // 收集拆分任务结果 while (tempTask != null) { if (tempTask.tryUnfork()) { sum += tempTask.getList().stream().mapToInt(n -> n).sum(); } else { tempTask.join(); sum += tempTask.getResult(); } tempTask = tempTask.getSubTask(); } result = sum; } } }

  斐波那契数列测试结果:

Copy
20:16:22.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55

  一组数求和测试结果:

Copy
20:28:56.607 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 普通计算 耗时:17 ms 20:28:56.614 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - Lambda计算 耗时:10 ms 20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin计算 耗时:13 ms 20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin非阻塞方式计算 耗时:10 ms

3. RecursiveTask#

Copy
import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.springframework.util.Assert; import org.springframework.util.StopWatch; import java.util.List; import java.util.Random; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.RecursiveTask; import java.util.stream.Collectors; import java.util.stream.IntStream; /** * RecursiveTask 单元测试 * * @author CL */ @Slf4j public class RecursiveTaskTest { /** * 测试斐波那契数列 */ @Test public void testFibonacci() { int n = 10; // 创建线程池,默认线程数为CPU核心数 ForkJoinPool forkJoinPool = new ForkJoinPool(); Fibonacci task = new Fibonacci(n); int result = forkJoinPool.invoke(task); log.info("第 {} 个斐波那契数为:{}", n, result); } /** * 计算斐波那契数列 */ @RequiredArgsConstructor private static class Fibonacci extends RecursiveTask<Integer> { private final int n; @Override protected Integer compute() { if (n <= 1) { return n; } Fibonacci f1 = new Fibonacci(n - 1); f1.fork(); Fibonacci f2 = new Fibonacci(n - 2); return f2.compute() + f1.join(); } } /** * 测试一组数求和 */ @Test public void testSum() { // 构造测试数据 int total = 5_0000; Random random = new Random(); List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList()); StopWatch stopWatch = new StopWatch(); stopWatch.start("普通计算"); int result1 = 0; for (Integer n : list) { result1 += n; } stopWatch.stop(); stopWatch.start("Lambda计算"); int result2 = list.stream().mapToInt(n -> n).sum(); stopWatch.stop(); stopWatch.start("ForkJoin计算"); ForkJoinPool forkJoinPool = new ForkJoinPool(); Sum task = new Sum(list); int result3 = forkJoinPool.invoke(task); stopWatch.stop(); stopWatch.start("ForkJoin非阻塞方式计算"); ForkJoinPool forkJoinPool2 = new ForkJoinPool(); Sum2 task2 = new Sum2(list, null); int result4 = forkJoinPool2.invoke(task2); stopWatch.stop(); Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误"); for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) { log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis()); } } /** * 一组数求和 */ @RequiredArgsConstructor private static class Sum extends RecursiveTask<Integer> { private final static int THRESHOLD = 1000; private final List<Integer> list; @Override protected Integer compute() { int total = list.size(); if (total <= THRESHOLD) { return list.stream().mapToInt(n -> n).sum(); } int middle = total / 2; Sum s1 = new Sum(list.subList(0, middle)); s1.fork(); Sum s2 = new Sum(list.subList(middle, total)); return s2.compute() + s1.join(); } } /** * 一组数求和<br/> * 本级任务不阻塞等待子任务执行,而是循环获取未执行的任务执行 */ @RequiredArgsConstructor private static class Sum2 extends RecursiveTask<Integer> { private final static int THRESHOLD = 1000; @Getter private final List<Integer> list; @Getter private final Sum2 subTask; @Override protected Integer compute() { int total = list.size(); int start = 0; Sum2 tempTask = null; // 拆分任务 while (total > THRESHOLD) { tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask); tempTask.fork(); start = start + THRESHOLD; total = total - THRESHOLD; } // 剩余最后一批任务 int sum = list.subList(start, list.size()).stream().mapToInt(n -> n).sum(); // 收集拆分任务结果 while (tempTask != null) { if (tempTask.tryUnfork()) { sum += tempTask.getList().stream().mapToInt(n -> n).sum(); } else { sum += tempTask.join(); } tempTask = tempTask.getSubTask(); } return sum; } } }

  斐波那契数列测试结果:

Copy
21:54:32.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55

  一组数求和测试结果:

Copy
21:54:33.092 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - 普通计算 耗时:7 ms 21:54:33.096 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - Lambda计算 耗时:6 ms 21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin计算 耗时:9 ms 21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin非阻塞方式计算 耗时:6 ms

4. CountedCompleter#

  CountedCompleter是Java8推出的ForkJoinTask实现类,相比于Java7退出的以上两个实现类来说,它在子任务耗时过长或阻塞时更加健壮,但是理解起来比较困难。
  该类有两个成员变量:completer保存当前任务的父任务,如果为null,则表示顶级任务;pending保存当前任务的子任务数,每当子任务执行完成则pending会减1,当pending==0时,表示当前任务执行完成。

Copy
public abstract class CountedCompleter<T> extends ForkJoinTask<T> { private static final long serialVersionUID = 5232453752276485070L; /** This task's completer, or null if none */ final CountedCompleter<?> completer; /** The number of pending tasks until completion */ volatile int pending; ...... }

  构造方法:

方法及参数 描述
CountedCompleter() 无父级任务,初始化子任务数为0
CountedCompleter(CountedCompleter<?> completer) 指定父级任务,初始化子任务数为0
CountedCompleter(CountedCompleter<?> completer, int initialPendingCount) 指定父级任务,初始化子任务数

  常用方法:

方法及参数 描述
abstract void compute() 执行的计算任务,必须实现
void onCompletion(CountedCompleter<?> caller) 调用tryComplete后,如果当前任务的所有子任务都完成(当前任务完成),则会调用该方法处理完成后的业务
boolean onExceptionalCompletion(Throwable ex, CountedCompleter<?> caller) 是否向父级任务传递异常,默认传递。即当前任务在执行compute()方法时抛出异常,或者显式的调用completeExceptionally(Throwable ex)抛出异常时,为true则父任务也异常完成,为false不会传递到父级任务
CountedCompleter<?> getCompleter() 返回当前任务的父级任务,没有返回null
int getPendingCount() 获取子任务数
void setPendingCount(int count) 添加子任务数(非原子性)
void addToPendingCount(int delta) 添加子任务数(原子性)
CountedCompleter<?> getRoot() 返回当前任务的父级任务,没有返回自身
void tryComplete() 尝试完成任务。如果当前任务的所有子任务全部完成(pending==0),调用onCompletion(CountedCompleter<?> caller)方法执行完成后的逻辑,并将当前任务的线程状态设置为NORMAL,通知被当前任务阻塞的任务执行,否则通过自旋tr将pending减1。
void propagateCompletion() 尝试完成任务。和tryComplete()的区别是:当所有子任务执行完成后,不会调用onCompletion(CountedCompleter<?> caller)方法。
void complete(T rawResult) 无论当前任务是否完成,直接设置结果为指定 的结果,调用onCompletion(CountedCompleter<?> caller)方法,将当前任务的线程的状态设置为NORMAL。如果存在父级任务,并调用tryComplete()尝试结束子任务。
CountedCompleter<?> firstComplete() 如果当前任务已经完成,则返回当前任务;否则将pending减1,并返回null。通常使用在只需要获取任一任务结果时使用。
CountedCompleter<?> nextComplete() 如果当前任务存在父级任务,则调用父级任务的firstComplete()方法,否则将当前任务的线程的状态设置为NORMAL,并返回null。
void quietlyCompleteRoot() 相当于getRoot().quietlyComplete()
void helpComplete(int maxTasks) 当前任务阻塞等待任务执行时,尝试帮助当前任务执行其他未处理的任务。
void internalPropagateException(Throwable ex) 如果当前任务执行异常,并需要想父级任务传递时,循环传递异常只到最顶级任务
boolean exec() 执行compute()方法,并返回false。ForkJoinPool调用该方法执行任务,如果返回true,ForkJoinPool会将当前任务标记为完成,并通知被当前任务阻塞的其他线程,这也是和RecursiveAction、RecursiveTask区别所在
T getRawResult() 获取结果。方法内部为空,需要重写
void setRawResult(T t) 设置结果。方法内部为空,需要重写
Copy
import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.springframework.util.Assert; import org.springframework.util.StopWatch; import java.util.List; import java.util.Random; import java.util.concurrent.CountedCompleter; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; /** * CountedCompleter 单元测试 * * @author CL */ @Slf4j public class CountedCompleterTest { /** * 测试一组数求和 */ @Test public void testSum() { // 构造测试数据 int total = 5_0000; Random random = new Random(); List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList()); StopWatch stopWatch = new StopWatch(); stopWatch.start("普通计算"); int result1 = 0; for (Integer n : list) { result1 += n; } stopWatch.stop(); stopWatch.start("Lambda计算"); int result2 = list.stream().mapToInt(n -> n).sum(); stopWatch.stop(); stopWatch.start("ForkJoin计算"); Sum task = new Sum(null, list, new AtomicReference<>(0)); int result3 = ForkJoinPool.commonPool().invoke(task); stopWatch.stop(); Assert.isTrue(result1 == result2 && result1 == result3, "计算结果错误"); for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) { log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis()); } } /** * 一组数求和 */ private static class Sum extends CountedCompleter<Integer> { private final static int THRESHOLD = 1000; private List<Integer> list; private AtomicReference<Integer> result; public Sum(CountedCompleter<Integer> parentTask, List<Integer> list, AtomicReference<Integer> result) { super(parentTask); this.list = list; this.result = result; } @Override public void compute() { int total = list.size(); if (total <= THRESHOLD) { result.getAndAccumulate(list.stream().mapToInt(n -> n).sum(), (a, b) -> a + b); // 不需要onCompletion(CountedCompleter)方法处理时,可以使用 propagateCompletion(); // 期望每个任务完成后继续执行onCompletion(CountedCompleter)方法时,可以使用 // tryComplete(); return; } int middle = total - THRESHOLD; List<Integer> subList = list.subList(middle, total); list = list.subList(0, middle); addToPendingCount(1); Sum s1 = new Sum(this, subList, result); s1.fork(); // 继续执行 this.exec(); } @Override public Integer getRawResult() { return result.get(); } } }

  一组数求和测试结果:

Copy
22:54:17.123 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - 普通计算 耗时:9 ms 22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - Lambda计算 耗时:8 ms 22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - ForkJoin计算 耗时:9 ms
posted @   C3Stones  阅读(338)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
点击右上角即可分享
微信分享提示
CONTENTS