【并发编程】Java7 - ForkJoin,将大任务拆分成小任务
1. 简介
Java7提供了可以将大任务拆分成小任务执行再合并结果的框架——Fork/Join。其中,将大任务拆分成足够执行的小任务并发执行的过程称为Fork,将这些小任务结果整合后形成最终的结果的过程称为Join。
Fork/Join框架的具体体现为ForkJoinTask抽象类,该类继承了Future,运行在ForkJoinPool线程池中。该类有三个是实现类:RecursiveAction、RecursiveTask、CountedCompleter(Java8新增)。
其中,RecursiveAction无返回值,RecursiveTask有返回值,CountedCompleter和RecursiceAction功能类似,但在子任务阻塞或耗时较长时做了增强,它支持重写指定方法返回结果,也可不返回结果(返回null),回具体请浏览下文详解。
需要注意的是,并不是任务拆分的越小越好,要根据业务需求,拆分成足够可以一次执行完成的任务,例如:需要一次查询5W条数据,可以拆分成拆分10个子任务,每个任务查询5K条,10个任务并发执行。
2. RecursiveAction
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* RecursiveAction 单元测试
*
* @author CL
*/
@Slf4j
public class RecursiveActionTest {
/**
* 测试斐波那契数列
*/
@Test
public void testFibonacci() {
int n = 10;
// 创建线程池,默认线程数为CPU核心数
ForkJoinPool forkJoinPool = new ForkJoinPool();
Fibonacci task = new Fibonacci(n);
forkJoinPool.invoke(task);
int result = task.getResult();
log.info("第 {} 个斐波那契数为:{}", n, result);
}
/**
* 计算斐波那契数列
*/
@RequiredArgsConstructor
private static class Fibonacci extends RecursiveAction {
private final int n;
@Getter
private int result;
@Override
protected void compute() {
if (n <= 1) {
result = n;
return;
}
Fibonacci f1 = new Fibonacci(n - 1);
Fibonacci f2 = new Fibonacci(n - 2);
ForkJoinTask.invokeAll(f1, f2);
result = f1.getResult() + f2.getResult();
}
}
/**
* 测试一组数求和
*/
@Test
public void testSum() {
// 构造测试数据
int total = 5_0000;
Random random = new Random();
List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());
StopWatch stopWatch = new StopWatch();
stopWatch.start("普通计算");
int result1 = 0;
for (Integer n : list) {
result1 += n;
}
stopWatch.stop();
stopWatch.start("Lambda计算");
int result2 = list.stream().mapToInt(n -> n).sum();
stopWatch.stop();
stopWatch.start("ForkJoin计算");
ForkJoinPool forkJoinPool = new ForkJoinPool();
Sum task = new Sum(list);
forkJoinPool.invoke(task);
int result3 = task.getResult();
stopWatch.stop();
stopWatch.start("ForkJoin非阻塞方式计算");
ForkJoinPool forkJoinPool2 = new ForkJoinPool();
Sum2 task2 = new Sum2(list, null);
forkJoinPool2.invoke(task2);
int result4 = task2.getResult();
stopWatch.stop();
Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误");
for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
}
}
/**
* 一组数求和
*/
@RequiredArgsConstructor
private static class Sum extends RecursiveAction {
private final static int THRESHOLD = 1000;
private final List<Integer> list;
@Getter
private int result;
@Override
protected void compute() {
int total = list.size();
if (total <= THRESHOLD) {
result = list.stream().mapToInt(n -> n).sum();
return;
}
int middle = total / 2;
Sum s1 = new Sum(list.subList(0, middle));
Sum s2 = new Sum(list.subList(middle, total));
ForkJoinTask.invokeAll(s1, s2);
result = s1.getResult() + s2.getResult();
}
}
/**
* 一组数求和<br/>
* 本级任务不阻塞等待子任务执行,而是循环获取未执行的任务执行
*/
@RequiredArgsConstructor
private static class Sum2 extends RecursiveAction {
private final static int THRESHOLD = 1000;
@Getter
private final List<Integer> list;
@Getter
private final Sum2 subTask;
@Getter
private int result;
@Override
protected void compute() {
int total = list.size();
int start = 0;
Sum2 tempTask = null;
// 拆分任务
while (total > THRESHOLD) {
tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask);
tempTask.fork();
start = start + THRESHOLD;
total = total - THRESHOLD;
}
// 剩余最后一批任务
int sum = list.subList(start, list.size()).stream().mapToInt(n -> n).sum();
// 收集拆分任务结果
while (tempTask != null) {
if (tempTask.tryUnfork()) {
sum += tempTask.getList().stream().mapToInt(n -> n).sum();
} else {
tempTask.join();
sum += tempTask.getResult();
}
tempTask = tempTask.getSubTask();
}
result = sum;
}
}
}
斐波那契数列测试结果:
20:16:22.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55
一组数求和测试结果:
20:28:56.607 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 普通计算 耗时:17 ms
20:28:56.614 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - Lambda计算 耗时:10 ms
20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin计算 耗时:13 ms
20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin非阻塞方式计算 耗时:10 ms
3. RecursiveTask
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* RecursiveTask 单元测试
*
* @author CL
*/
@Slf4j
public class RecursiveTaskTest {
/**
* 测试斐波那契数列
*/
@Test
public void testFibonacci() {
int n = 10;
// 创建线程池,默认线程数为CPU核心数
ForkJoinPool forkJoinPool = new ForkJoinPool();
Fibonacci task = new Fibonacci(n);
int result = forkJoinPool.invoke(task);
log.info("第 {} 个斐波那契数为:{}", n, result);
}
/**
* 计算斐波那契数列
*/
@RequiredArgsConstructor
private static class Fibonacci extends RecursiveTask<Integer> {
private final int n;
@Override
protected Integer compute() {
if (n <= 1) {
return n;
}
Fibonacci f1 = new Fibonacci(n - 1);
f1.fork();
Fibonacci f2 = new Fibonacci(n - 2);
return f2.compute() + f1.join();
}
}
/**
* 测试一组数求和
*/
@Test
public void testSum() {
// 构造测试数据
int total = 5_0000;
Random random = new Random();
List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());
StopWatch stopWatch = new StopWatch();
stopWatch.start("普通计算");
int result1 = 0;
for (Integer n : list) {
result1 += n;
}
stopWatch.stop();
stopWatch.start("Lambda计算");
int result2 = list.stream().mapToInt(n -> n).sum();
stopWatch.stop();
stopWatch.start("ForkJoin计算");
ForkJoinPool forkJoinPool = new ForkJoinPool();
Sum task = new Sum(list);
int result3 = forkJoinPool.invoke(task);
stopWatch.stop();
stopWatch.start("ForkJoin非阻塞方式计算");
ForkJoinPool forkJoinPool2 = new ForkJoinPool();
Sum2 task2 = new Sum2(list, null);
int result4 = forkJoinPool2.invoke(task2);
stopWatch.stop();
Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误");
for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
}
}
/**
* 一组数求和
*/
@RequiredArgsConstructor
private static class Sum extends RecursiveTask<Integer> {
private final static int THRESHOLD = 1000;
private final List<Integer> list;
@Override
protected Integer compute() {
int total = list.size();
if (total <= THRESHOLD) {
return list.stream().mapToInt(n -> n).sum();
}
int middle = total / 2;
Sum s1 = new Sum(list.subList(0, middle));
s1.fork();
Sum s2 = new Sum(list.subList(middle, total));
return s2.compute() + s1.join();
}
}
/**
* 一组数求和<br/>
* 本级任务不阻塞等待子任务执行,而是循环获取未执行的任务执行
*/
@RequiredArgsConstructor
private static class Sum2 extends RecursiveTask<Integer> {
private final static int THRESHOLD = 1000;
@Getter
private final List<Integer> list;
@Getter
private final Sum2 subTask;
@Override
protected Integer compute() {
int total = list.size();
int start = 0;
Sum2 tempTask = null;
// 拆分任务
while (total > THRESHOLD) {
tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask);
tempTask.fork();
start = start + THRESHOLD;
total = total - THRESHOLD;
}
// 剩余最后一批任务
int sum = list.subList(start, list.size()).stream().mapToInt(n -> n).sum();
// 收集拆分任务结果
while (tempTask != null) {
if (tempTask.tryUnfork()) {
sum += tempTask.getList().stream().mapToInt(n -> n).sum();
} else {
sum += tempTask.join();
}
tempTask = tempTask.getSubTask();
}
return sum;
}
}
}
斐波那契数列测试结果:
21:54:32.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55
一组数求和测试结果:
21:54:33.092 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - 普通计算 耗时:7 ms
21:54:33.096 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - Lambda计算 耗时:6 ms
21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin计算 耗时:9 ms
21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin非阻塞方式计算 耗时:6 ms
4. CountedCompleter
CountedCompleter是Java8推出的ForkJoinTask实现类,相比于Java7退出的以上两个实现类来说,它在子任务耗时过长或阻塞时更加健壮,但是理解起来比较困难。
该类有两个成员变量:completer保存当前任务的父任务,如果为null,则表示顶级任务;pending保存当前任务的子任务数,每当子任务执行完成则pending会减1,当pending==0时,表示当前任务执行完成。
public abstract class CountedCompleter<T> extends ForkJoinTask<T> {
private static final long serialVersionUID = 5232453752276485070L;
/** This task's completer, or null if none */
final CountedCompleter<?> completer;
/** The number of pending tasks until completion */
volatile int pending;
......
}
构造方法:
方法及参数 | 描述 |
---|---|
CountedCompleter() | 无父级任务,初始化子任务数为0 |
CountedCompleter(CountedCompleter<?> completer) | 指定父级任务,初始化子任务数为0 |
CountedCompleter(CountedCompleter<?> completer, int initialPendingCount) | 指定父级任务,初始化子任务数 |
常用方法:
方法及参数 | 描述 |
---|---|
abstract void compute() | 执行的计算任务,必须实现 |
void onCompletion(CountedCompleter<?> caller) | 调用tryComplete后,如果当前任务的所有子任务都完成(当前任务完成),则会调用该方法处理完成后的业务 |
boolean onExceptionalCompletion(Throwable ex, CountedCompleter<?> caller) | 是否向父级任务传递异常,默认传递。即当前任务在执行compute()方法时抛出异常,或者显式的调用completeExceptionally(Throwable ex)抛出异常时,为true则父任务也异常完成,为false不会传递到父级任务 |
CountedCompleter<?> getCompleter() | 返回当前任务的父级任务,没有返回null |
int getPendingCount() | 获取子任务数 |
void setPendingCount(int count) | 添加子任务数(非原子性) |
void addToPendingCount(int delta) | 添加子任务数(原子性) |
CountedCompleter<?> getRoot() | 返回当前任务的父级任务,没有返回自身 |
void tryComplete() | 尝试完成任务。如果当前任务的所有子任务全部完成(pending==0),调用onCompletion(CountedCompleter<?> caller)方法执行完成后的逻辑,并将当前任务的线程状态设置为NORMAL,通知被当前任务阻塞的任务执行,否则通过自旋tr将pending减1。 |
void propagateCompletion() | 尝试完成任务。和tryComplete()的区别是:当所有子任务执行完成后,不会调用onCompletion(CountedCompleter<?> caller)方法。 |
void complete(T rawResult) | 无论当前任务是否完成,直接设置结果为指定 的结果,调用onCompletion(CountedCompleter<?> caller)方法,将当前任务的线程的状态设置为NORMAL。如果存在父级任务,并调用tryComplete()尝试结束子任务。 |
CountedCompleter<?> firstComplete() | 如果当前任务已经完成,则返回当前任务;否则将pending减1,并返回null。通常使用在只需要获取任一任务结果时使用。 |
CountedCompleter<?> nextComplete() | 如果当前任务存在父级任务,则调用父级任务的firstComplete()方法,否则将当前任务的线程的状态设置为NORMAL,并返回null。 |
void quietlyCompleteRoot() | 相当于getRoot().quietlyComplete() |
void helpComplete(int maxTasks) | 当前任务阻塞等待任务执行时,尝试帮助当前任务执行其他未处理的任务。 |
void internalPropagateException(Throwable ex) | 如果当前任务执行异常,并需要想父级任务传递时,循环传递异常只到最顶级任务 |
boolean exec() | 执行compute()方法,并返回false。ForkJoinPool调用该方法执行任务,如果返回true,ForkJoinPool会将当前任务标记为完成,并通知被当前任务阻塞的其他线程,这也是和RecursiveAction、RecursiveTask区别所在 |
T getRawResult() | 获取结果。方法内部为空,需要重写 |
void setRawResult(T t) | 设置结果。方法内部为空,需要重写 |
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* CountedCompleter 单元测试
*
* @author CL
*/
@Slf4j
public class CountedCompleterTest {
/**
* 测试一组数求和
*/
@Test
public void testSum() {
// 构造测试数据
int total = 5_0000;
Random random = new Random();
List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());
StopWatch stopWatch = new StopWatch();
stopWatch.start("普通计算");
int result1 = 0;
for (Integer n : list) {
result1 += n;
}
stopWatch.stop();
stopWatch.start("Lambda计算");
int result2 = list.stream().mapToInt(n -> n).sum();
stopWatch.stop();
stopWatch.start("ForkJoin计算");
Sum task = new Sum(null, list, new AtomicReference<>(0));
int result3 = ForkJoinPool.commonPool().invoke(task);
stopWatch.stop();
Assert.isTrue(result1 == result2 && result1 == result3, "计算结果错误");
for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
}
}
/**
* 一组数求和
*/
private static class Sum extends CountedCompleter<Integer> {
private final static int THRESHOLD = 1000;
private List<Integer> list;
private AtomicReference<Integer> result;
public Sum(CountedCompleter<Integer> parentTask, List<Integer> list, AtomicReference<Integer> result) {
super(parentTask);
this.list = list;
this.result = result;
}
@Override
public void compute() {
int total = list.size();
if (total <= THRESHOLD) {
result.getAndAccumulate(list.stream().mapToInt(n -> n).sum(), (a, b) -> a + b);
// 不需要onCompletion(CountedCompleter)方法处理时,可以使用
propagateCompletion();
// 期望每个任务完成后继续执行onCompletion(CountedCompleter)方法时,可以使用
// tryComplete();
return;
}
int middle = total - THRESHOLD;
List<Integer> subList = list.subList(middle, total);
list = list.subList(0, middle);
addToPendingCount(1);
Sum s1 = new Sum(this, subList, result);
s1.fork();
// 继续执行
this.exec();
}
@Override
public Integer getRawResult() {
return result.get();
}
}
}
一组数求和测试结果:
22:54:17.123 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - 普通计算 耗时:9 ms
22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - Lambda计算 耗时:8 ms
22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - ForkJoin计算 耗时:9 ms