Summary: This article studies the ForkJoin framework for multithreaded programming.
Environment
Windows 10 Enterprise LTSC 21H2, Java 1.8
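1 Introduction
The ForkJoin framework was introduced in JDK 1.7. It splits a large task into subtasks, which may themselves be split further, and then merges the subtasks' results step by step until they form the result of the original task.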
The ForkJoin framework does two things:
Fork: split a large task into subtasks.
Join: merge the subtasks' results into the result of the large task.
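The implementation of the ForkJoin framework is quite complex and makes heavy use of bit manipulation and lock-free algorithms. Its core components are:
ForkJoinPool: a thread pool based on the work-stealing algorithm, responsible for global task scheduling and load balancing.
ForkJoinTask: a task unit that can be forked and joined recursively; a built-in status field drives completion notification.
ForkJoinWorkerThread: a thread that owns its own queue, runs its local tasks first, and steals tasks from other queues when idle.
WorkQueue: a double-ended queue that supports both FIFO (first-in, first-out) and LIFO (last-in, first-out) access for efficient task distribution and stealing. The owning worker's operations and other threads' steals are lock-free (ordered writes plus CAS), while the shared queues that receive external submissions are guarded by a lightweight spinlock field, qlock.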
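A minimal sketch of that access pattern, assuming java.util.concurrent.ConcurrentLinkedDeque as a stand-in (the real WorkQueue is a custom array-based deque, not this class): the owner pushes and pops at one end, so it sees LIFO order, while a thief takes from the other end, so it sees FIFO order. StealDemo and its methods are hypothetical names.

```java
import java.util.concurrent.ConcurrentLinkedDeque;

// Illustration only: ConcurrentLinkedDeque stands in for ForkJoinPool.WorkQueue.
class StealDemo {
    // The owner thread pushes new subtasks onto the "top" (head) of its deque.
    static void ownerPush(ConcurrentLinkedDeque<String> dq, String task) {
        dq.push(task);
    }

    // The owner takes the most recently pushed task first (LIFO).
    static String ownerPop(ConcurrentLinkedDeque<String> dq) {
        return dq.pollFirst();
    }

    // A thief takes the oldest task from the opposite end (FIFO).
    static String steal(ConcurrentLinkedDeque<String> dq) {
        return dq.pollLast();
    }

    public static void main(String[] args) {
        ConcurrentLinkedDeque<String> dq = new ConcurrentLinkedDeque<>();
        ownerPush(dq, "t1");
        ownerPush(dq, "t2");
        ownerPush(dq, "t3");
        System.out.println(ownerPop(dq)); // t3
        System.out.println(steal(dq));    // t1
    }
}
```

Taking from opposite ends keeps the owner and thieves mostly out of each other's way, and thieves get the oldest tasks, which in divide-and-conquer code tend to be the largest.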
2 Classes and Interfaces
2.1 ForkJoinPool
ForkJoinPool is the fork/join pool. Like the thread pool class ThreadPoolExecutor, it is an implementation of the ExecutorService interface.
ForkJoinPool provides three public constructors:
```java
public ForkJoinPool();
public ForkJoinPool(int parallelism);
public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory,
                    UncaughtExceptionHandler handler, boolean asyncMode);
```
All of them end up calling the following private constructor:
```java
private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory,
                     UncaughtExceptionHandler handler, int mode,
                     String workerNamePrefix);
```
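Parameters:
parallelism: the parallelism level, which defaults to the number of CPU cores (Runtime.getRuntime().availableProcessors()). The number of threads in the pool is governed by this value, but it is not a hard maximum on the thread count.
factory: the thread factory, DefaultForkJoinWorkerThreadFactory by default, which simply creates ForkJoinWorkerThread objects.
handler: the handler for uncaught exceptions thrown in worker threads.
mode: the scheduling mode. FIFO_QUEUE means local queues are first-in, first-out; LIFO_QUEUE means they are last-in, first-out. The public constructors map asyncMode = true to FIFO_QUEUE and asyncMode = false (the default) to LIFO_QUEUE.
workerNamePrefix: the prefix of worker thread names.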
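The sketch below builds pools through the public constructors and reads the settings back with getParallelism() and getAsyncMode(). PoolConfigDemo is a hypothetical name and the parallelism value 4 is arbitrary.

```java
import java.util.concurrent.ForkJoinPool;

class PoolConfigDemo {
    // Parallelism 4, the default thread factory, no custom exception handler,
    // and asyncMode = true (FIFO local queues instead of the default LIFO).
    static ForkJoinPool newAsyncPool() {
        return new ForkJoinPool(4,
                ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                null,   // no UncaughtExceptionHandler
                true);  // asyncMode
    }

    public static void main(String[] args) {
        ForkJoinPool defaultPool = new ForkJoinPool();
        ForkJoinPool asyncPool = newAsyncPool();
        // The no-arg constructor uses the number of available processors.
        System.out.println(defaultPool.getParallelism()
                == Runtime.getRuntime().availableProcessors()); // true
        System.out.println(defaultPool.getAsyncMode()); // false
        System.out.println(asyncPool.getParallelism()); // 4
        System.out.println(asyncPool.getAsyncMode());   // true
        defaultPool.shutdown();
        asyncPool.shutdown();
    }
}
```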
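Member fields:
config: the pool configuration, an int (32 bits):
The low 16 bits hold parallelism.
Bit 17 (1 << 16) holds the mode: 0 means the queue is LIFO, 1 means FIFO.
Bit 32 (1 << 31, the sign bit) marks shared mode: 0 is a normal queue, 1 is a shared queue that has no owner thread, so its tasks can only be stolen by worker threads. This flag (SHARED_QUEUE) is set in each WorkQueue's own config field, which uses the same layout except that its low 16 bits hold the queue's index in workQueues.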
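ctl: the main control field of ForkJoinPool, a long (64 bits):
Bits 63-48 (AC) hold the number of active threads minus parallelism, in two's complement. The field is incremented when a thread is activated and decremented when one is deactivated. Once it has been incremented parallelism times, bit 63 flips to 0 (ctl is no longer negative), and no more threads are activated.
Bits 47-32 (TC) hold the total number of threads minus parallelism, in two's complement. The field is incremented when a thread is created and decremented when one terminates. Once it has been incremented parallelism times, bit 47 flips to 0, and no more threads are created.
Bits 31-0 (SP) hold the scanState of the local queue of the top thread on the stack of inactive (idle) threads:
Bits 15-0 hold the index of that queue in the workQueues array.
Bits 31-16 hold that thread's version count and state (bit 31 is the INACTIVE flag).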
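workQueues: the array of WorkQueues. Queues at odd indexes are bound to worker threads and receive the tasks those threads fork locally; queues at even indexes have no owner and only receive tasks submitted from outside the pool.
factory: the factory used to create threads.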
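To make the config and ctl layouts concrete, here is a sketch that reproduces their arithmetic in plain Java. The constant values are copied from the JDK 8 source on the assumption that the article targets Java 1.8 (later JDKs changed several of them); PoolBits and its helper methods are hypothetical names, not JDK API.

```java
// Bit-layout constants assumed from the JDK 8 ForkJoinPool source.
class PoolBits {
    // ---- config (int) ----
    static final int SMASK        = 0xffff;   // low 16 bits: parallelism or queue index
    static final int LIFO_QUEUE   = 0;
    static final int FIFO_QUEUE   = 1 << 16;  // bit 17
    static final int SHARED_QUEUE = 1 << 31;  // bit 32, the sign bit

    // ---- ctl (long) ----
    static final int  AC_SHIFT   = 48;
    static final long AC_UNIT    = 0x0001L << AC_SHIFT;
    static final long AC_MASK    = 0xffffL << AC_SHIFT;
    static final int  TC_SHIFT   = 32;
    static final long TC_UNIT    = 0x0001L << TC_SHIFT;
    static final long TC_MASK    = 0xffffL << TC_SHIFT;
    static final long ADD_WORKER = 0x0001L << (TC_SHIFT + 15); // sign bit of TC

    // As in the private constructor: config = (parallelism & SMASK) | mode
    static int poolConfig(int parallelism, boolean asyncMode) {
        return (parallelism & SMASK) | (asyncMode ? FIFO_QUEUE : LIFO_QUEUE);
    }

    static boolean isFifo(int config) { return (config & FIFO_QUEUE) != 0; }

    // As in the constructor: AC = TC = -parallelism, SP = 0
    static long initialCtl(int parallelism) {
        long np = -parallelism;
        return ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

    // Same arithmetic as tryAddWorker: bump AC and TC, masking each field so
    // that a carry out of TC cannot leak into AC
    static long addWorker(long c) {
        return (AC_MASK & (c + AC_UNIT)) | (TC_MASK & (c + TC_UNIT));
    }

    static int activeCount(long c, int parallelism) {
        return (int) (c >> AC_SHIFT) + parallelism;
    }

    static int totalCount(long c, int parallelism) {
        return (short) (c >>> TC_SHIFT) + parallelism;
    }

    public static void main(String[] args) {
        int config = poolConfig(8, true);
        System.out.println(Integer.toHexString(config));             // 10008
        System.out.println((config & SMASK) + " " + isFifo(config)); // 8 true
        System.out.println((6 | SHARED_QUEUE) < 0);                  // true

        int p = 4;
        long c = initialCtl(p);
        System.out.println(Long.toHexString(c)); // fffcfffc00000000
        while ((c & ADD_WORKER) != 0L)           // the condition tryAddWorker checks
            c = addWorker(c);
        System.out.println(activeCount(c, p) + " " + totalCount(c, p) + " " + c); // 4 4 0
    }
}
```

Because AC and TC start at -parallelism, ctl stays negative and ADD_WORKER stays set until parallelism workers have been added, which is exactly what signalWork() and tryAddWorker() test.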
2.2 ForkJoinTask
ForkJoinTask is an abstract class that implements the Future interface. It provides fork() for splitting off subtasks and join() for merging their results.
In ThreadPoolExecutor, the execute() method that the pool uses to run tasks takes a Runnable. ForkJoinPool accepts not only Runnable instances but also ForkJoinTask instances, and any Runnable passed in is adapted into a ForkJoinTask.
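A short sketch of that adaptation: ForkJoinPool.submit(Runnable) returns a ForkJoinTask<?>, and ForkJoinTask.adapt(Runnable) performs the same wrapping explicitly. RunnableAdaptDemo is a hypothetical name; the counter only records that the Runnable ran.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicInteger;

class RunnableAdaptDemo {
    // Submits a plain Runnable; the pool wraps it and hands back a ForkJoinTask.
    static ForkJoinTask<?> submitRunnable(ForkJoinPool pool, AtomicInteger counter) {
        Runnable r = counter::incrementAndGet;
        return pool.submit(r);
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool(2);
        AtomicInteger counter = new AtomicInteger();
        ForkJoinTask<?> task = submitRunnable(pool, counter);
        task.join();                       // waits like any other ForkJoinTask
        System.out.println(counter.get()); // 1
        // ForkJoinTask.adapt performs the same wrapping explicitly.
        Runnable bump = counter::incrementAndGet;
        ForkJoinTask<?> adapted = ForkJoinTask.adapt(bump);
        pool.invoke(adapted);
        System.out.println(counter.get()); // 2
        pool.shutdown();
    }
}
```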
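2.3 RecursiveTask
When you need a ForkJoinTask, you normally do not subclass ForkJoinTask directly; you subclass one of its subclasses instead:
RecursiveAction: for tasks that return no result.
RecursiveTask: for tasks that return a result; this is the most commonly used one.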
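For comparison with the RecursiveTask demo in section 4, here is a hedged RecursiveAction sketch. DoubleAction is a hypothetical task that doubles every array element in place; since there is no result to combine, it uses invokeAll() to fork both halves and wait for them. The threshold of 4 is arbitrary.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

class DoubleAction extends RecursiveAction {
    private static final int THRESHOLD = 4; // arbitrary cut-off for this sketch
    private final int[] data;
    private final int from; // inclusive
    private final int to;   // exclusive

    DoubleAction(int[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected void compute() {
        if (to - from <= THRESHOLD) {
            for (int i = from; i < to; i++) {
                data[i] *= 2;
            }
        } else {
            int mid = (from + to) >>> 1;
            // invokeAll runs both subtasks and returns when both are done
            invokeAll(new DoubleAction(data, from, mid), new DoubleAction(data, mid, to));
        }
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        ForkJoinPool pool = new ForkJoinPool();
        pool.invoke(new DoubleAction(data, 0, data.length));
        System.out.println(Arrays.toString(data)); // [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
        pool.shutdown();
    }
}
```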
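2.4 ForkJoinWorkerThread
ForkJoinWorkerThread is a subclass of Thread that executes tasks as a thread in the pool. Internally it holds a double-ended task queue of type WorkQueue.
When executing tasks, a thread first processes the tasks in its local queue (in LIFO or FIFO order, depending on the mode); when its local queue is empty, it steals tasks from other queues, always taking them from the base end (FIFO).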
2.5 WorkQueue
WorkQueue is a static nested class of ForkJoinPool: a double-ended queue that stores ForkJoinTask instances.
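3 Source Code
3.1 Submitting Tasks
ForkJoinPool's submit() method submits a task: it adds the task to an external queue and wakes up a thread to run it:
```java
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
    if (task == null)
        throw new NullPointerException();
    externalPush(task); // push into a shared (even-index) queue
    return task;
}
```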
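3.2 Forking Tasks
ForkJoinTask's fork() method splits off a task: it pushes the task onto the current worker thread's own queue if the caller is a ForkJoinWorkerThread, or onto the common pool (ForkJoinPool.common) otherwise, and a thread is woken up to run it:
```java
public final ForkJoinTask<V> fork() {
    Thread t;
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
        // called from a worker: push onto that worker's own (odd-index) queue
        ((ForkJoinWorkerThread)t).workQueue.push(this);
    else
        // called from any other thread: submit to the common pool
        ForkJoinPool.common.externalPush(this);
    return this;
}
```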
3.3 Adding Tasks
3.3.1 Adding a Task to a Local Queue
WorkQueue's push() method adds a task to the local queue and wakes up a thread if the queue was empty or nearly empty:
```java
final void push(ForkJoinTask<?> task) {
    ForkJoinTask<?>[] a; ForkJoinPool p;
    int b = base, s = top, n;
    if ((a = array) != null) {    // ignore if queue removed
        int m = a.length - 1;     // fenced write for task visibility
        U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
        U.putOrderedInt(this, QTOP, s + 1);
        if ((n = s - b) <= 1) {   // queue was (nearly) empty: wake a worker
            if ((p = pool) != null)
                p.signalWork(p.workQueues, this);
        }
        else if (n >= m)          // array is full: double its size
            growArray();
    }
}
```
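The index arithmetic in push() is easier to follow in a stripped-down form. The sketch below is a hypothetical, single-threaded ToyWorkQueue: it keeps the power-of-two array, the (length - 1) & index masking, and the grow-when-full rule, but omits Unsafe, memory ordering, and CAS, so it is not thread-safe and is not how WorkQueue is actually written.

```java
// Single-threaded illustration of WorkQueue's array layout (hypothetical class).
class ToyWorkQueue {
    private Object[] array = new Object[4]; // tiny on purpose; must be a power of two
    private int base; // next slot a thief takes from
    private int top;  // next free slot for the owner

    // Owner: write at top, advance top, grow when full (like push/growArray)
    void push(Object task) {
        int m = array.length - 1;
        array[m & top] = task;
        top++;
        if (top - base > m)
            grow();
    }

    // Copy live elements into a twice-as-large array, keeping their logical indices
    private void grow() {
        Object[] old = array;
        Object[] a = new Object[old.length << 1];
        int oldMask = old.length - 1, newMask = a.length - 1;
        for (int i = base; i != top; i++)
            a[newMask & i] = old[oldMask & i];
        array = a;
    }

    // Owner in LIFO mode: take the newest task from the top
    Object pop() {
        if (top == base)
            return null;
        int i = (array.length - 1) & --top;
        Object t = array[i];
        array[i] = null;
        return t;
    }

    // Thief (or owner in FIFO mode): take the oldest task from the base
    Object poll() {
        if (top == base)
            return null;
        int i = (array.length - 1) & base++;
        Object t = array[i];
        array[i] = null;
        return t;
    }

    int size() { return top - base; }

    int capacity() { return array.length; }

    public static void main(String[] args) {
        ToyWorkQueue q = new ToyWorkQueue();
        for (int i = 0; i < 10; i++)
            q.push(i);
        System.out.println(q.capacity()); // 16 (grew 4 -> 8 -> 16)
        System.out.println(q.pop());      // 9
        System.out.println(q.poll());     // 0
        System.out.println(q.size());     // 8
    }
}
```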
3.3.2 Adding a Task to an External Queue
ForkJoinPool's externalPush() method tries a fast path: it locks a shared queue at an even index, adds the task, and wakes up a thread. If that is not possible, it falls back to externalSubmit():
```java
final void externalPush(ForkJoinTask<?> task) {
    WorkQueue[] ws; WorkQueue q; int m;
    int r = ThreadLocalRandom.getProbe();     // per-thread random probe
    int rs = runState;
    if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
        (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
        U.compareAndSwapInt(q, QLOCK, 0, 1)) { // lock the shared queue
        ForkJoinTask<?>[] a; int am, n, s;
        if ((a = q.array) != null &&
            (am = a.length - 1) > (n = (s = q.top) - q.base)) {
            int j = ((am & s) << ASHIFT) + ABASE;
            U.putOrderedObject(a, j, task);
            U.putOrderedInt(q, QTOP, s + 1);
            U.putIntVolatile(q, QLOCK, 0);    // unlock
            if (n <= 1)
                signalWork(ws, q);
            return;
        }
        U.compareAndSwapInt(q, QLOCK, 1, 0);  // unlock: array missing or full
    }
    externalSubmit(task);                     // slow path
}
```
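The odd/even split of workQueues comes from two index formulas: externalPush() uses m & r & SQMASK, and registerWorker() (section 3.5.1) uses ((s << 1) | 1) & m. The hypothetical QueueIndex helper below reproduces just those expressions, with SQMASK assumed to be 0x007e as in the JDK 8 source; collision probing and array growth are left out.

```java
// Index formulas from JDK 8 externalPush and registerWorker (assumed).
class QueueIndex {
    static final int SQMASK = 0x007e; // even bits only: at most 64 shared slots

    // Shared queue slot chosen for an external submitter: m & probe & SQMASK
    static int externalIndex(int probe, int length) {
        return (length - 1) & probe & SQMASK;
    }

    // Worker queue slot chosen by registerWorker: ((seed << 1) | 1) & m
    static int workerIndex(int seed, int length) {
        return ((seed << 1) | 1) & (length - 1);
    }

    public static void main(String[] args) {
        int length = 16; // workQueues.length is always a power of two
        for (int x : new int[] {0x1234, 0x5555, 0x7fff})
            System.out.println(externalIndex(x, length) + " / " + workerIndex(x, length));
    }
}
```

Bit 0 of SQMASK is clear, so external slots are always even; the | 1 forces bit 0 on, so worker slots are always odd.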
ForkJoinPool's externalSubmit() method initializes workQueues and creates the external queue when necessary, adds the task, and wakes up a thread:
```java
private void externalSubmit(ForkJoinTask<?> task) {
    int r;                                    // initialize caller's probe
    if ((r = ThreadLocalRandom.getProbe()) == 0) {
        ThreadLocalRandom.localInit();
        r = ThreadLocalRandom.getProbe();
    }
    for (;;) {
        WorkQueue[] ws; WorkQueue q; int rs, m, k;
        boolean move = false;
        if ((rs = runState) < 0) {            // pool is shutting down
            tryTerminate(false, false);
            throw new RejectedExecutionException();
        }
        else if ((rs & STARTED) == 0 ||       // initialize workQueues
                 ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
            int ns = 0;
            rs = lockRunState();
            try {
                if ((rs & STARTED) == 0) {
                    U.compareAndSwapObject(this, STEALCOUNTER, null,
                                           new AtomicLong());
                    // create workQueues array with size a power of two
                    int p = config & SMASK;   // ensure at least 2 slots
                    int n = (p > 1) ? p - 1 : 1;
                    n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                    n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                    workQueues = new WorkQueue[n];
                    ns = STARTED;
                }
            } finally {
                unlockRunState(rs, (rs & ~RSLOCK) | ns);
            }
        }
        else if ((q = ws[k = r & m & SQMASK]) != null) {
            if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                ForkJoinTask<?>[] a = q.array;
                int s = q.top;
                boolean submitted = false; // initial submission or resizing
                try {                      // locked version of push
                    if ((a != null && a.length > s + 1 - q.base) ||
                        (a = q.growArray()) != null) {
                        int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                        U.putOrderedObject(a, j, task);
                        U.putOrderedInt(q, QTOP, s + 1);
                        submitted = true;
                    }
                } finally {
                    U.compareAndSwapInt(q, QLOCK, 1, 0);
                }
                if (submitted) {
                    signalWork(ws, q);
                    return;
                }
            }
            move = true;                   // move on failure
        }
        else if (((rs = runState) & RSLOCK) == 0) { // create new queue
            q = new WorkQueue(this, null);
            q.hint = r;
            q.config = k | SHARED_QUEUE;
            q.scanState = INACTIVE;
            rs = lockRunState();           // publish index
            if (rs > 0 && (ws = workQueues) != null &&
                k < ws.length && ws[k] == null)
                ws[k] = q;                 // else terminated
            unlockRunState(rs, rs & ~RSLOCK);
        }
        else
            move = true;                   // move if busy
        if (move)
            r = ThreadLocalRandom.advanceProbe(r);
    }
}
```
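The sizing block inside externalSubmit() rounds p - 1 up to the next value of the form 2^k - 1 by smearing its highest set bit to the right, then computes (n + 1) << 1. The hypothetical helper below reproduces that arithmetic so it can be checked for a few parallelism values; the result is always a power of two and at least 2 × parallelism.

```java
// Reproduces the workQueues sizing arithmetic from JDK 8 externalSubmit (assumed).
class QueueArraySize {
    static int workQueuesLength(int parallelism) {
        int n = (parallelism > 1) ? parallelism - 1 : 1;
        n |= n >>> 1;  // copy the highest set bit into every lower position...
        n |= n >>> 2;
        n |= n >>> 4;
        n |= n >>> 8;
        n |= n >>> 16; // ...so n is now 2^k - 1
        return (n + 1) << 1;
    }

    public static void main(String[] args) {
        for (int p : new int[] {1, 2, 4, 5, 8, 16})
            System.out.println(p + " -> " + workQueuesLength(p));
        // 1 -> 4, 2 -> 4, 4 -> 8, 5 -> 16, 8 -> 16, 16 -> 32
    }
}
```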
3.4 Waking Up Threads
ForkJoinPool's signalWork() method wakes up an idle thread, or tries to create a new one if no thread is idle:
```java
final void signalWork(WorkQueue[] ws, WorkQueue q) {
    long c; int sp, i; WorkQueue v; Thread p;
    while ((c = ctl) < 0L) {                  // too few active workers
        if ((sp = (int)c) == 0) {             // no idle workers
            if ((c & ADD_WORKER) != 0L)       // too few workers
                tryAddWorker(c);
            break;
        }
        if (ws == null)                       // unstarted/terminated
            break;
        if (ws.length <= (i = sp & SMASK))    // terminated
            break;
        if ((v = ws[i]) == null)              // terminating
            break;
        int vs = (sp + SS_SEQ) & ~INACTIVE;   // next scanState
        int d = sp - v.scanState;             // screen CAS
        long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
        if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
            v.scanState = vs;                 // activate v
            if ((p = v.parker) != null)
                U.unpark(p);
            break;
        }
        if (q != null && q.base == q.top)     // no more work
            break;
    }
}
```
ForkJoinPool's tryAddWorker() method increments both AC and TC in ctl with a CAS and, if that succeeds, creates a new worker:
```java
private void tryAddWorker(long c) {
    boolean add = false;
    do {
        long nc = ((AC_MASK & (c + AC_UNIT)) |
                   (TC_MASK & (c + TC_UNIT)));
        if (ctl == c) {
            int rs, stop;                 // check if terminating
            if ((stop = (rs = lockRunState()) & STOP) == 0)
                add = U.compareAndSwapLong(this, CTL, c, nc);
            unlockRunState(rs, rs & ~RSLOCK);
            if (stop != 0)
                break;
            if (add) {
                createWorker();
                break;
            }
        }
    } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
```
ForkJoinPool's createWorker() method creates and starts a thread through the factory; if that fails, deregisterWorker() rolls the counts back:
```java
private boolean createWorker() {
    ForkJoinWorkerThreadFactory fac = factory;
    Throwable ex = null;
    ForkJoinWorkerThread wt = null;
    try {
        if (fac != null && (wt = fac.newThread(this)) != null) {
            wt.start();
            return true;
        }
    } catch (Throwable rex) {
        ex = rex;
    }
    deregisterWorker(wt, ex); // undo the ctl increments made by tryAddWorker
    return false;
}
```
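3.5 Managing Threads
3.5.1 Creating Threads
The newThread() method of the default factory (DefaultForkJoinWorkerThreadFactory, a nested class of ForkJoinPool) creates the thread:
```java
public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
    return new ForkJoinWorkerThread(pool);
}
```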
The ForkJoinWorkerThread constructor:
```java
protected ForkJoinWorkerThread(ForkJoinPool pool) {
    // Use a placeholder until a useful name can be set in registerWorker
    super("aForkJoinWorkerThread");
    this.pool = pool;
    this.workQueue = pool.registerWorker(this);
}
```
ForkJoinPool's registerWorker() method registers the new thread: it marks the thread as a daemon, creates its WorkQueue, stores the queue at an odd index of workQueues, and names the thread:
```java
final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
    UncaughtExceptionHandler handler;
    wt.setDaemon(true);                           // configure thread
    if ((handler = ueh) != null)
        wt.setUncaughtExceptionHandler(handler);
    WorkQueue w = new WorkQueue(this, wt);
    int i = 0;                                    // assign a pool index
    int mode = config & MODE_MASK;
    int rs = lockRunState();
    try {
        WorkQueue[] ws; int n;                    // skip if no array
        if ((ws = workQueues) != null && (n = ws.length) > 0) {
            int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
            int m = n - 1;
            i = ((s << 1) | 1) & m;               // odd-numbered indices
            if (ws[i] != null) {                  // collision
                int probes = 0;                   // step by approx half n
                int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                while (ws[i = (i + step) & m] != null) {
                    if (++probes >= n) {
                        workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                        m = n - 1;
                        probes = 0;
                    }
                }
            }
            w.hint = s;                           // use as random seed
            w.config = i | mode;
            w.scanState = i;                      // publication fence
            ws[i] = w;
        }
    } finally {
        unlockRunState(rs, rs & ~RSLOCK);
    }
    wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
    return w;
}
```
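One way to observe what registerWorker() sets up is to inspect a worker from inside a task. The sketch assumes the default thread factory; WorkerNameDemo is a hypothetical name. The pool and worker numbers in the name are not guaranteed, so only the "ForkJoinPool-...-worker-..." shape (the prefix JDK 8 builds in the public constructor) and the daemon flag are worth relying on.

```java
import java.util.concurrent.ForkJoinPool;

class WorkerNameDemo {
    // Runs a task in a fresh pool and reports the worker's name and daemon flag.
    static String[] inspectWorker() {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            return pool.submit(() -> {
                Thread t = Thread.currentThread();
                return new String[] {t.getName(), String.valueOf(t.isDaemon())};
            }).join(); // join() throws no checked exceptions, unlike get()
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        String[] info = inspectWorker();
        System.out.println(info[0] + " daemon=" + info[1]);
    }
}
```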
3.5.2 Starting Threads
Calling Thread's start() method starts the thread; the JVM then invokes ForkJoinWorkerThread's run() method on the new thread:
```java
public void run() {
    if (workQueue.array == null) { // only run once
        Throwable exception = null;
        try {
            onStart();
            pool.runWorker(workQueue);
        } catch (Throwable ex) {
            exception = ex;
        } finally {
            try {
                onTermination(exception);
            } catch (Throwable ex) {
                if (exception == null)
                    exception = ex;
            } finally {
                pool.deregisterWorker(this, exception);
            }
        }
    }
}
```
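run() calls the onStart() and onTermination() hooks around runWorker(), and both are protected methods of ForkJoinWorkerThread that subclasses may override. The hedged sketch below plugs a hypothetical CountingWorker into a pool through the ForkJoinWorkerThreadFactory parameter of the public constructor and counts how many workers reached onStart().

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

class CountingWorker extends ForkJoinWorkerThread {
    static final AtomicInteger STARTED = new AtomicInteger();

    CountingWorker(ForkJoinPool pool) {
        super(pool); // protected constructor; registers the worker with the pool
    }

    @Override
    protected void onStart() {
        super.onStart(); // keep the default initialization
        STARTED.incrementAndGet();
    }

    static ForkJoinPool newCountingPool(int parallelism) {
        // ForkJoinWorkerThreadFactory has one method, newThread(ForkJoinPool),
        // so a constructor reference can serve as the factory
        return new ForkJoinPool(parallelism, CountingWorker::new, null, false);
    }

    public static void main(String[] args) {
        ForkJoinPool pool = newCountingPool(2);
        pool.submit(() -> System.out.println(Thread.currentThread().getName())).join();
        System.out.println("workers started: " + STARTED.get());
        pool.shutdown();
    }
}
```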
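ForkJoinPool's runWorker() method is the worker's main loop: it repeatedly scans for a task and runs it, or waits for work, and exits when awaitWork() returns false:
```java
final void runWorker(WorkQueue w) {
    w.growArray();                   // allocate queue
    int seed = w.hint;               // initially holds randomization hint
    int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
    for (ForkJoinTask<?> t;;) {
        if ((t = scan(w, r)) != null)
            w.runTask(t);
        else if (!awaitWork(w, r))
            break;
        r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
    }
}
```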
3.5.3 Deregistering Threads
ForkJoinPool's deregisterWorker() method removes the worker's queue, decrements AC and TC, cancels the remaining tasks, and wakes or replaces a worker if needed:
```java
final void deregisterWorker(ForkJoinWorkerThread wt, Throwable ex) {
    WorkQueue w = null;
    if (wt != null && (w = wt.workQueue) != null) {
        WorkQueue[] ws;                           // remove index from array
        int idx = w.config & SMASK;
        int rs = lockRunState();
        if ((ws = workQueues) != null && ws.length > idx && ws[idx] == w)
            ws[idx] = null;
        unlockRunState(rs, rs & ~RSLOCK);
    }
    long c;                                       // decrement counts
    do {} while (!U.compareAndSwapLong
                 (this, CTL, c = ctl, ((AC_MASK & (c - AC_UNIT)) |
                                       (TC_MASK & (c - TC_UNIT)) |
                                       (SP_MASK & c))));
    if (w != null) {
        w.qlock = -1;                             // ensure set
        w.transferStealCount(this);
        w.cancelAll();                            // cancel remaining tasks
    }
    for (;;) {                                    // possibly replace
        WorkQueue[] ws; int m, sp;
        if (tryTerminate(false, false) || w == null || w.array == null ||
            (runState & STOP) != 0 || (ws = workQueues) == null ||
            (m = ws.length - 1) < 0)              // already terminating
            break;
        if ((sp = (int)(c = ctl)) != 0) {         // wake up replacement
            if (tryRelease(c, ws[sp & m], AC_UNIT))
                break;
        }
        else if (ex != null && (c & ADD_WORKER) != 0L) {
            tryAddWorker(c);                      // create replacement
            break;
        }
        else                                      // don't need replacement
            break;
    }
    if (ex == null)                               // help clean on way out
        ForkJoinTask.helpExpungeStaleExceptions();
    else                                          // rethrow
        ForkJoinTask.rethrow(ex);
}
```
3.6 Managing Tasks
3.6.1 Stealing Tasks
ForkJoinPool's scan() method looks for a task to steal, starting at a random index; if repeated full sweeps find nothing, it inactivates the worker:
```java
private ForkJoinTask<?> scan(WorkQueue w, int r) {
    WorkQueue[] ws; int m;
    if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
        int ss = w.scanState;                     // initially non-negative
        for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
            WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
            int b, n; long c;
            if ((q = ws[k]) != null) {
                if ((n = (b = q.base) - q.top) < 0 &&
                    (a = q.array) != null) {      // non-empty
                    long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                    if ((t = ((ForkJoinTask<?>)
                              U.getObjectVolatile(a, i))) != null &&
                        q.base == b) {
                        if (ss >= 0) {
                            if (U.compareAndSwapObject(a, i, t, null)) {
                                q.base = b + 1;
                                if (n < -1)       // signal others
                                    signalWork(ws, q);
                                return t;
                            }
                        }
                        else if (oldSum == 0 &&   // try to activate
                                 w.scanState < 0)
                            tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                    }
                    if (ss < 0)                   // refresh
                        ss = w.scanState;
                    r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                    origin = k = r & m;           // move and rescan
                    oldSum = checkSum = 0;
                    continue;
                }
                checkSum += b;
            }
            if ((k = (k + 1) & m) == origin) {    // continue until stable
                if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                    oldSum == (oldSum = checkSum)) {
                    if (ss < 0 || w.qlock < 0)    // already inactive
                        break;
                    int ns = ss | INACTIVE;       // try to inactivate
                    long nc = ((SP_MASK & ns) |
                               (UC_MASK & ((c = ctl) - AC_UNIT)));
                    w.stackPred = (int)c;         // hold prev stack top
                    U.putInt(w, QSCANSTATE, ns);
                    if (U.compareAndSwapLong(this, CTL, c, nc))
                        ss = ns;
                    else
                        w.scanState = ss;         // back out
                }
                checkSum = 0;
            }
        }
    }
    return null;
}
```
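Stripped of concurrency control, the core of scan() is a circular walk over workQueues starting at origin = r & m, taking the task at the base (oldest end) of the first non-empty queue. The single-threaded sketch below shows only that walk; ScanSketch is a hypothetical name, ArrayDeque stands in for WorkQueue (first = base, last = top), and the CAS, rerandomization, and inactivation logic of the real method are omitted.

```java
import java.util.ArrayDeque;

class ScanSketch {
    // Returns the oldest task of the first non-empty queue found, or null.
    static String scan(ArrayDeque<String>[] queues, int r) {
        int m = queues.length - 1; // queues.length must be a power of two
        int origin = r & m;
        int k = origin;
        do {
            ArrayDeque<String> q = queues[k];
            if (q != null && !q.isEmpty())
                return q.pollFirst(); // base end: FIFO steal
            k = (k + 1) & m;
        } while (k != origin);
        return null; // nothing found in a full sweep
    }

    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        ArrayDeque<String>[] queues = new ArrayDeque[4];
        queues[1] = new ArrayDeque<>();
        queues[3] = new ArrayDeque<>();
        queues[3].addLast("a"); // oldest
        queues[3].addLast("b"); // newest
        System.out.println(scan(queues, 2)); // a (slot 2 is empty, slot 3 has work)
        System.out.println(scan(queues, 0)); // b (slots 0, 1, 2 are empty)
        queues[3].clear();
        System.out.println(scan(queues, 1)); // null
    }
}
```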
3.6.2 Running Tasks
WorkQueue's runTask() method runs a stolen task and then drains the worker's own local tasks:
```java
final void runTask(ForkJoinTask<?> task) {
    if (task != null) {
        scanState &= ~SCANNING;            // mark as busy
        (currentSteal = task).doExec();
        U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
        execLocalTasks();                  // then run this worker's own tasks
        ForkJoinWorkerThread thread = owner;
        if (++nsteals < 0)                 // collect on overflow
            transferStealCount(pool);
        scanState |= SCANNING;
        if (thread != null)
            thread.afterTopLevelExec();
    }
}
```
3.6.3 Blocking Idle Workers
ForkJoinPool's awaitWork() method spins and then parks an inactive worker until it is signaled; it returns false if the worker should exit:
```java
private boolean awaitWork(WorkQueue w, int r) {
    if (w == null || w.qlock < 0)                 // w is terminating
        return false;
    for (int pred = w.stackPred, spins = SPINS, ss;;) {
        if ((ss = w.scanState) >= 0)              // reactivated
            break;
        else if (spins > 0) {
            r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
            if (r >= 0 && --spins == 0) {         // randomize spins
                WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
                if (pred != 0 && (ws = workQueues) != null &&
                    (j = pred & SMASK) < ws.length &&
                    (v = ws[j]) != null &&        // see if pred parking
                    (v.parker == null || v.scanState >= 0))
                    spins = SPINS;                // continue spinning
            }
        }
        else if (w.qlock < 0)                     // recheck after spins
            return false;
        else if (!Thread.interrupted()) {
            long c, prevctl, parkTime, deadline;
            int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
            if ((ac <= 0 && tryTerminate(false, false)) ||
                (runState & STOP) != 0)           // pool terminating
                return false;
            if (ac <= 0 && ss == (int)c) {        // is last waiter
                prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
                int t = (short)(c >>> TC_SHIFT);  // shrink excess spares
                if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
                    return false;                 // else use timed wait
                parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
                deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
            }
            else
                prevctl = parkTime = deadline = 0L;
            Thread wt = Thread.currentThread();
            U.putObject(wt, PARKBLOCKER, this);   // emulate LockSupport
            w.parker = wt;
            if (w.scanState < 0 && ctl == c)      // recheck before park
                U.park(false, parkTime);
            U.putOrderedObject(w, QPARKER, null);
            U.putObject(wt, PARKBLOCKER, null);
            if (w.scanState >= 0)
                break;
            if (parkTime != 0L && ctl == c &&
                deadline - System.nanoTime() <= 0L &&
                U.compareAndSwapLong(this, CTL, c, prevctl))
                return false;                     // shrink pool
        }
    }
    return true;
}
```
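The timed-wait branch above applies only to the last waiter (ac <= 0 and the worker is on top of the idle stack); every other idle worker parks with parkTime = 0, that is, without a timeout. The hypothetical helper below isolates the timeout arithmetic, assuming IDLE_TIMEOUT is 2 seconds as in the JDK 8 source, and ignores the case where the shrink CAS fails and the worker falls back to a timed wait.

```java
// Idle-timeout arithmetic from JDK 8 awaitWork (constant value assumed).
class IdleTimeout {
    static final long IDLE_TIMEOUT = 2000L * 1000L * 1000L; // 2 s in nanoseconds

    // t is the TC field of ctl: total workers minus parallelism.
    // Returns -1 when the last waiter would exit at once (more than
    // parallelism + 2 workers exist); otherwise the timed park in nanoseconds.
    static long parkNanos(int t) {
        if (t > 2)
            return -1L;
        return IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
    }

    public static void main(String[] args) {
        for (int t : new int[] {3, 0, -1, -3})
            System.out.println("t=" + t + " -> " + parkNanos(t));
        // t=3 -> -1, t=0 -> 2000000000, t=-1 -> 4000000000, t=-3 -> 8000000000
    }
}
```

The fewer threads the pool has relative to parallelism (the more negative t is), the longer the last idle worker waits before the pool shrinks further.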
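3.7 Joining Tasks
ForkJoinTask's join() method returns the task's result, rethrowing the exception if the task was cancelled or completed abnormally:
```java
public final V join() {
    int s;
    if ((s = doJoin() & DONE_MASK) != NORMAL)
        reportException(s); // CANCELLED or EXCEPTIONAL: throw
    return getRawResult();
}
```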
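ForkJoinTask's doJoin() method returns the status at once if the task is already done. A worker thread first tries to pop the task from the top of its own queue and run it directly, and otherwise helps or waits through awaitJoin(); any other thread blocks in externalAwaitDone():
```java
private int doJoin() {
    int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
    return (s = status) < 0 ? s :                    // already completed
        ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (w = (wt = (ForkJoinWorkerThread)t).workQueue).
        tryUnpush(this) && (s = doExec()) < 0 ? s :  // on top: run it here
        wt.pool.awaitJoin(w, this, 0L) :             // otherwise help or wait
        externalAwaitDone();                         // non-worker: block
}
```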
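The reportException() path can be observed with a task that always fails. In JDK 8, ForkJoinPool.invoke() pushes the task and then calls task.join(), so the exception thrown in compute() resurfaces in the caller with the same type (possibly as a copy whose cause is the original). FailingTask is a hypothetical name; note that Future.get() would instead wrap the failure in an ExecutionException.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class FailingTask extends RecursiveTask<Integer> {
    @Override
    protected Integer compute() {
        throw new IllegalStateException("boom");
    }

    // Returns the exception seen by the caller, or null if none was thrown.
    static RuntimeException runAndCatch() {
        ForkJoinPool pool = new ForkJoinPool(2);
        FailingTask task = new FailingTask();
        try {
            pool.invoke(task); // waits for the task, then reports its outcome
            return null;
        } catch (RuntimeException e) {
            return e;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        RuntimeException e = runAndCatch();
        System.out.println(e.getClass().getSimpleName()); // IllegalStateException
    }
}
```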
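4 Usage
The task class needs to return a result, so it extends RecursiveTask and overrides the compute method.
A task is split off with ForkJoinTask's fork method, and join waits for that task to finish and returns its result.
Example: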
```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class Demo {
    public static void main(String[] args) {
        SumTask sumTask = new SumTask(1, 100);
        ForkJoinPool pool = new ForkJoinPool();
        try {
            ForkJoinTask<Integer> task = pool.submit(sumTask);
            System.out.println(task.get()); // blocks until the whole task tree is done
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            pool.shutdown();
        }
    }
}

class SumTask extends RecursiveTask<Integer> {
    private static final int THRESHOLD = 10; // ranges this small are summed directly
    private final int begin;
    private final int end;

    public SumTask(int begin, int end) {
        this.begin = begin;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        int value = 0;
        if (end - begin <= THRESHOLD) {
            for (int i = begin; i <= end; i++) {
                value += i;
            }
        } else {
            int middle = (begin + end) / 2;
            SumTask beginTask = new SumTask(begin, middle);
            SumTask endTask = new SumTask(middle + 1, end);
            beginTask.fork();                    // push the left half onto this worker's queue
            int endValue = endTask.compute();    // compute the right half in the current thread
            value = beginTask.join() + endValue; // wait for (or run) the left half
        }
        return value;
    }
}
```
Result:
```
5050
```