forkjoin框架
forkjoin框架
一:简介
从JDK1.7开始,Java提供Fork/Join框架用于并行执行任务,它的思想就是讲一个大任务分割成若干小任务,最终汇总每个小任务的结果得到这个大任务的结果。
这种思想和MapReduce很像(input --> split --> map --> reduce --> output)
主要有两步:
- 第一、任务切分;
- 第二、结果合并
它的模型大致是这样的:线程池中的每个线程都有自己的工作队列(PS:这一点和ThreadPoolExecutor不同,ThreadPoolExecutor是所有线程公用一个工作队列,所有线程都从这个工作队列中取任务),当自己队列中的任务都完成以后,会从其它线程的工作队列中偷一个任务执行,这样可以充分利用资源。下面盗一张图来宏观展示一下:
二:forkjoin中定义的角色
ForkJoinPool:充当fork/join框架里面的管理者,最原始的任务都要交给它才能处理。它负责控制整个fork/join有多少个workerThread,workerThread的创建,激活都是由它来掌控。它还负责workQueue队列的创建和分配,每当创建一个workerThread,为它负责分配相应的workQueue。然后它把接到的活都交给workerThread去处理,它可以说是整个frok/join的容器。
ForkJoinWorkerThread:fork/join里面真正干活的"工人",本质是一个线程。里面有一个ForkJoinPool.WorkQueue的队列存放着它要干的活,接活之前它要向ForkJoinPool注册(registerWorker),拿到相应的workQueue。然后就从workQueue里面拿任务出来处理。它是依附于ForkJoinPool而存活,如果ForkJoinPool的销毁了,它也会跟着结束。
我们所说的forkjoin的工作窃取,那么究竟是怎么窃取的呢?我么你分析一下任务是由workThread来窃取的,workThread是一个线程,线程的执行逻辑都是在run里面,所以任务的窃取逻辑一定在run()中可以找的到。
public void run() { //线程run方法
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
onStart();
pool.runWorker(workQueue); //在这里处理任务队列!
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
final void runWorker(WorkQueue w) {
w.growArray(); // allocate queue 进行队列的初始化
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
for (ForkJoinTask<?> t;;) { //又是无限循环处理任务!
if ((t = scan(w, r)) != null) //在这里获取任务!
w.runTask(t);
else if (!awaitWork(w, r))
break;
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
窃取逻辑的主要代码:(scan)任务的窃取从workerThread运行的那一刻就开始了,先随机选中一条队列看看能不能窃取到下一条队列,如果都窃取不到就返回null。
/**
* Scans for and tries to steal a top-level task. Scans start at a
* random location, randomly moving on apparent contention,
* otherwise continuing linearly until reaching two consecutive
* empty passes over all queues with the same checksum (summing
* each base index of each queue, that moves on each steal), at
* which point the worker tries to inactivate and then re-scans,
* attempting to re-activate (itself or some other worker) if
* finding a task; otherwise returning null to await work. Scans
* otherwise touch as little memory as possible, to reduce
* disruption on other scanning threads.
*
* @param w the worker (via its WorkQueue)
* @param r a random seed
* @return a task, or null if none found
*/
private ForkJoinTask<?> scan(WorkQueue w, int r) {
WorkQueue[] ws; int m;
if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
int ss = w.scanState; // initially non-negative
for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
int b, n; long c;
if ((q = ws[k]) != null) { //随机选中了非空队列 q
if ((n = (b = q.base) - q.top) < 0 &&
(a = q.array) != null) { // non-empty
long i = (((a.length - 1) & b) << ASHIFT) + ABASE; //从尾部出队,b是尾部下标
if ((t = ((ForkJoinTask<?>)
U.getObjectVolatile(a, i))) != null &&
q.base == b) {
if (ss >= 0) {
if (U.compareAndSwapObject(a, i, t, null)) { //利用cas出队
q.base = b + 1;
if (n < -1) // signal others
signalWork(ws, q);
return t; //出队成功,成功窃取一个任务!
}
}
else if (oldSum == 0 && // try to activate 队列没有激活,尝试激活
w.scanState < 0)
tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
}
if (ss < 0) // refresh
ss = w.scanState;
r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
origin = k = r & m; // move and rescan
oldSum = checkSum = 0;
continue;
}
checkSum += b;
}<br> //k = k + 1表示取下一个队列 如果(k + 1) & m == origin表示 已经遍历完所有队列了
if ((k = (k + 1) & m) == origin) { // continue until stable
if ((ss >= 0 || (ss == (ss = w.scanState))) &&
oldSum == (oldSum = checkSum)) {
if (ss < 0 || w.qlock < 0) // already inactive
break;
int ns = ss | INACTIVE; // try to inactivate
long nc = ((SP_MASK & ns) |
(UC_MASK & ((c = ctl) - AC_UNIT)));
w.stackPred = (int)c; // hold prev stack top
U.putInt(w, QSCANSTATE, ns);
if (U.compareAndSwapLong(this, CTL, c, nc))
ss = ns;
else
w.scanState = ss; // back out
}
checkSum = 0;
}
}
}
return null;
}
ForkJoinPool.WorkQueue: 双端队列就是它,它负责存储接收的任务。
ForkJoinTask:代表fork/join里面任务类型,我们一般用它的两个子类RecursiveTask、RecursiveAction。这两个区别在于RecursiveTask任务是有返回值,RecursiveAction没有返回值。任务的处理逻辑包括任务的切分都集中在compute()方法里面。
此外还有fork()方法:在当前线程运行的线程池中安排一个异步执行。简单的理解就是在创建一个子任务。
join()方法:当任务完成的时候返回计算结果。
invoke()方法:开始执行任务,如果必要,等待计算完成。
RecursiveAction ()方法:一个递归无结果的ForkJoinTask(没有返回值)
RecursiveTask 一个递归有结果的ForkJoinTask(有返回值)
三:工作窃取
工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。工作窃取的运行流程图如下:
那么为什么需要使用工作窃取算法呢?
假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。
工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争,其缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。
使用示例:
package com.duoxiancheng.juc;
import java.util.concurrent.RecursiveTask;
public class ForkJoinDemo extends RecursiveTask<Long> {
private Long start;
private Long end;
/** 临界值 */
private static final Long temp = 10000L;
public ForkJoinDemo(Long start,Long end) {
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
/** 超过中间值,就分配任务 */
if(end - start < temp) {
Long sum = 0L;
for(Long i = start;i <= end;i++) {
sum += i;
}
return sum;
} else {
/** 获取中间值 */
Long middle = (end + start) / 2;
ForkJoinDemo right = new ForkJoinDemo(start,middle);
/** 开启分支计算任务 */
right.fork();
ForkJoinDemo left = new ForkJoinDemo(middle+1,end);
/** 开启分支计算任务 */
left.fork();
/** 合并结果 */
return right.join() + left.join();
}
}
}
package com.cjs.boot.demo;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
public class ForkJoinTaskDemo {
private class SumTask extends RecursiveTask<Integer> {
private static final int THRESHOLD = 20;
private int arr[];
private int start;
private int end;
public SumTask(int[] arr, int start, int end) {
this.arr = arr;
this.start = start;
this.end = end;
}
/**
* 小计
*/
private Integer subtotal() {
Integer sum = 0;
for (int i = start; i < end; i++) {
sum += arr[i];
}
System.out.println(Thread.currentThread().getName() + ": ∑(" + start + "~" + end + ")=" + sum);
return sum;
}
@Override
protected Integer compute() {
if ((end - start) <= THRESHOLD) {
return subtotal();
}else {
int middle = (start + end) / 2;
SumTask left = new SumTask(arr, start, middle);
SumTask right = new SumTask(arr, middle, end);
left.fork();
right.fork();
return left.join() + right.join();
}
}
}
public static void main(String[] args) throws ExecutionException, InterruptedException {
int[] arr = new int[100];
for (int i = 0; i < 100; i++) {
arr[i] = i + 1;
}
// 创建一个ForkJoinPool线程池,用来存放任务
ForkJoinPool pool = new ForkJoinPool();
// 返回有结果的任务,RecursiveTask
ForkJoinTask<Integer> result = pool.submit(new ForkJoinTaskDemo().new SumTask(arr, 0, arr.length));
System.out.println("最终计算结果: " + result.invoke());
pool.shutdown();
}
}
参考链接:
https://www.cnblogs.com/cjsblog/p/9078341.html
https://www.cnblogs.com/linlinismine/p/9295701.html