ForkJoinPool Source Code Analysis


Overview

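ForkJoinPool was written by Doug Lea and added in JDK 1.7. Its goal is to make full use of multi-core CPUs through divide and conquer. It creates multiple threads and multiple queues, and each thread processes its own queue. Once a thread has finished its own tasks, it steals tasks from other threads' queues, which keeps every CPU busy. ForkJoinPool is used in many places, most notably by the parallel streams and the asynchronous CompletableFuture class added in JDK 1.8. The class is fairly complex. Tactically we should take it seriously: read it patiently, set aside the finer details, and get through the overall flow first. Strategically we can take it lightly. The earlier articles on the ordinary thread pool and the scheduled thread pool already showed the pattern. In its simplest form: a task is submitted to the pool, the pool creates a thread, the thread is started, and the thread's run method calls the task's run method. ForkJoinPool is also a thread pool and follows roughly the same logic.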

Take a look at the ForkJoinTask flow chart.

ForkJoinPool usage example

As before, the example can be found on GitHub.

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.stream.IntStream;

public class ForkJoinPoolTest {

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        // IntStream.parallel().sum() would obviously give the result more easily,
        // and parallel streams also run on ForkJoinPool (more on that later);
        // this example is about testing how a ForkJoinTask gets split up
        int[] numbers = IntStream.rangeClosed(0, 1_000_000).toArray();
        long begin = System.currentTimeMillis();
        // the sum (500,000,500,000) does not fit in an int, so the task returns Long
        ForkJoinTask<Long> submit = forkJoinPool.submit(new SumTask(numbers, 0, numbers.length - 1));
        System.out.println("Sum: " + submit.get());
        System.out.println("Elapsed (ms): " + (System.currentTimeMillis() - begin));
    }

    private static class SumTask extends RecursiveTask<Long> {
        private final int[] numbers;
        private final int from;
        private final int to;

        public SumTask(int[] numbers, int from, int to) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= 2) {
                // small enough: add the elements directly
                long total = 0;
                for (int i = from; i <= to; i++) {
                    total += numbers[i];
                }
                return total;
            } else {
                // split the range in half and fork both halves
                int middle = (from + to) / 2;
                SumTask taskLeft = new SumTask(numbers, from, middle);
                SumTask taskRight = new SumTask(numbers, middle + 1, to);
                taskLeft.fork();
                taskRight.fork();
                return taskRight.join() + taskLeft.join();
                // return taskLeft.join() + taskRight.join(); // reversed join order
            }
        }
    }
}

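The result is not shown here; it is simply the sum. One point deserves attention. In compute, after calling taskLeft.fork() and taskRight.fork(), the code evaluates taskRight.join() first and then adds taskLeft.join(). If you write it the other way round, the run takes nearly twice as long. Why? Let's leave that question open for now and answer it later.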

Submitting a task

The submit method

public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
    // the task must not be null
    if (task == null)
        throw new NullPointerException();
    externalPush(task);
    return task;
}
final void externalPush(ForkJoinTask<?> task) {
    WorkQueue[] ws; WorkQueue q; int m;
    // probe is a per-thread value, private to the calling thread
    int r = ThreadLocalRandom.getProbe();
    int rs = runState;
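    // fast path: try a quick enqueue and return on success; otherwise externalSubmit does the full enqueue
    // if the queue array exists and the queue this thread maps to exists, lock that queue and enqueue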
    if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
        (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 && // picks a pseudo-random even-indexed queue
        U.compareAndSwapInt(q, QLOCK, 0, 1)) { // lock the WorkQueue
        ForkJoinTask<?>[] a; int am, n, s;
        if ((a = q.array) != null &&
            (am = a.length - 1) > (n = (s = q.top) - q.base)) {
            int j = ((am & s) << ASHIFT) + ABASE;
            U.putOrderedObject(a, j, task); // enqueue the task
            U.putOrderedInt(q, QTOP, s + 1);
            U.putIntVolatile(q, QLOCK, 0); // unlock
            if (n <= 1) // the queue held at most one task before this push: wake an idle worker or create a new one
                signalWork(ws, q);
            return;
        }
        U.compareAndSwapInt(q, QLOCK, 1, 0);
    }
    // full enqueue; an external thread's first submit always ends up here (its probe r is still 0)
    externalSubmit(task);
}
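The slot chosen by the fast path, `ws[m & r & SQMASK]`, can be isolated in a few lines. The sketch below assumes the JDK 8 value `SQMASK = 0x007e`, since the listing only shows the name. The helper `sharedIndex` is hypothetical and is not a JDK method. The sketch shows why external submissions always land on an even index: bit 0 of SQMASK is clear, so bit 0 of the result is always 0.

```java
public class SharedQueueIndex {
    // Assumed to mirror the JDK 8 constant ForkJoinPool.SQMASK (0x007e):
    // bit 0 is cleared, so the result is always even; bits 1..6 allow 64 slots.
    static final int SQMASK = 0x007e;

    /** Hypothetical helper: slot an external submitter with probe r maps to, for a table of length n. */
    static int sharedIndex(int r, int n) {
        int m = n - 1;              // n is a power of two, so m is an all-ones mask
        return m & r & SQMASK;      // same expression as ws[m & r & SQMASK] in externalPush
    }

    public static void main(String[] args) {
        int n = 16;
        int[] probes = {0x9e3779b9, 12345, -7, 1};
        for (int r : probes) {
            System.out.println("probe " + r + " -> slot " + sharedIndex(r, n));
        }
    }
}
```

Because SQMASK has only bits 1 to 6 set, external submitters can reach at most 64 even slots, however large the table grows.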

The externalSubmit method

private void externalSubmit(ForkJoinTask<?> task) {
    int r; // initialize callers probe 
    // if this thread's probe has not been initialized yet, initialize it
    if ((r = ThreadLocalRandom.getProbe()) == 0) {
        ThreadLocalRandom.localInit();
        r = ThreadLocalRandom.getProbe();
    }
    // an endless loop, which guarantees that the WorkQueue[] array is created, the queue is created, and the task is enqueued
    for (;;) {
        WorkQueue[] ws; WorkQueue q; int rs, m, k;
        boolean move = false;
        if ((rs = runState) < 0) {
            tryTerminate(false, false);     // help terminate
            throw new RejectedExecutionException();
        }
        else if ((rs & STARTED) == 0 ||     // initialize: create the WorkQueue[] array
                 ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
            int ns = 0;
            rs = lockRunState();
            try {
                if ((rs & STARTED) == 0) {
                    U.compareAndSwapObject(this, STEALCOUNTER, null,
                                           new AtomicLong());
                    // create workQueues array with size a power of two
                    int p = config & SMASK; // ensure at least 2 slots
                    int n = (p > 1) ? p - 1 : 1;
                    n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                    n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                    workQueues = new WorkQueue[n];
                    ns = STARTED;
                }
            } finally {
                unlockRunState(rs, (rs & ~RSLOCK) | ns);
            }
        }
        else if ((q = ws[k = r & m & SQMASK]) != null) { // enqueue the task
            if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                ForkJoinTask<?>[] a = q.array;
                int s = q.top;
                boolean submitted = false; // initial submission or resizing
                try {                      // locked version of push
                    if ((a != null && a.length > s + 1 - q.base) ||
                        (a = q.growArray()) != null) {
                        int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                        U.putOrderedObject(a, j, task);
                        U.putOrderedInt(q, QTOP, s + 1);
                        submitted = true;
                    }
                } finally {
                    U.compareAndSwapInt(q, QLOCK, 1, 0);
                }
                if (submitted) { // enqueued successfully: wake or create a worker to process it
                    signalWork(ws, q);
                    return;
                }
            }
            move = true;                   // move on failure
        }
        else if (((rs = runState) & RSLOCK) == 0) { // create new queue
            q = new WorkQueue(this, null);
            q.hint = r;
            q.config = k | SHARED_QUEUE;
            q.scanState = INACTIVE;
            rs = lockRunState();           // publish index
            if (rs > 0 &&  (ws = workQueues) != null &&
                k < ws.length && ws[k] == null)
                ws[k] = q;                 // else terminated
            unlockRunState(rs, rs & ~RSLOCK);
        }
        else
            move = true;                   // move if busy
        // a lock (queue or runState) could not be taken, so another thread is contending: rehash the probe to try a different slot
        if (move)
            r = ThreadLocalRandom.advanceProbe(r);
    }
}
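The initialization branch sizes the WorkQueue[] array from the parallelism stored in `config & SMASK`. The sketch below pulls that bit-twiddling into a standalone function so the result is easy to check. The name `initialTableSize` is hypothetical. The result is twice the smallest power of two that is at least the parallelism, with a floor of 4 slots.

```java
public class WorkQueuesSize {
    /**
     * Hypothetical helper: recomputes the initial WorkQueue[] length the way
     * externalSubmit does, taking the parallelism (config & SMASK) as input.
     */
    static int initialTableSize(int parallelism) {
        int p = parallelism;
        int n = (p > 1) ? p - 1 : 1;
        // Smear the highest set bit downward so that n becomes 2^k - 1.
        n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
        n |= n >>> 8; n |= n >>> 16;
        // n + 1 is the smallest power of two >= p (at least 2); doubling it
        // leaves room for both even (shared) and odd (worker) slots.
        return (n + 1) << 1;
    }

    public static void main(String[] args) {
        for (int p : new int[] {1, 2, 3, 4, 5, 8, 16}) {
            System.out.println("parallelism " + p + " -> WorkQueue[" + initialTableSize(p) + "]");
        }
    }
}
```

The doubling means a pool with parallelism p has roughly one even slot for external submitters and one odd slot for a worker per unit of parallelism.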

As you can see, both the fast path and the full enqueue call signalWork once the task has been enqueued successfully.
The signalWork method

final void signalWork(WorkQueue[] ws, WorkQueue q) {
    long c; int sp, i; WorkQueue v; Thread p;
    // loop: if there is an idle worker, wake it; if there are too few workers, create a new one
    while ((c = ctl) < 0L) {                       // too few active
        if ((sp = (int)c) == 0) {                  // no idle workers
            if ((c & ADD_WORKER) != 0L)            // too few workers
                tryAddWorker(c); // too few workers: create a new worker thread
            break;
        }
        if (ws == null)                            // unstarted/terminated
            break;
        if (ws.length <= (i = sp & SMASK))         // terminated
            break;
        if ((v = ws[i]) == null)                   // terminating
            break;
        int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
        int d = sp - v.scanState;                  // screen CAS
        long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
        if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
            v.scanState = vs;                      // activate v
            if ((p = v.parker) != null)
                U.unpark(p); // wake the parked worker thread
            break;
        }
        if (q != null && q.base == q.top)          // no more work
            break;
    }
}

Now let's look at the method that creates new worker threads

private void tryAddWorker(long c) {
    boolean add = false;
    // the same pattern again: first CAS ctl to bump the worker counts; if that succeeds, call createWorker
    do {
        long nc = ((AC_MASK & (c + AC_UNIT)) |
                   (TC_MASK & (c + TC_UNIT)));
        if (ctl == c) {
            int rs, stop;                 // check if terminating
            if ((stop = (rs = lockRunState()) & STOP) == 0)
                add = U.compareAndSwapLong(this, CTL, c, nc);
            unlockRunState(rs, rs & ~RSLOCK);
            if (stop != 0)
                break;
            if (add) {
                createWorker(); // create the new thread
                break;
            }
        }
    } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
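signalWork and tryAddWorker both read counts packed into the 64-bit `ctl` field. The sketch below copies the JDK 8 field layout. The active-count offset sits in bits 48 to 63 and the total-count offset in bits 32 to 47, and both start at minus the parallelism. The low 32 bits hold the idle-worker stack read as `sp`, and they are left at zero here. The constant values come from the JDK 8 source and are an assumption for other JDK versions. The class name `CtlSketch` is hypothetical. The sketch shows three things:

- why `ctl < 0` means "too few active workers";
- why `ADD_WORKER` stays set while total workers are short;
- how the `nc` computed in tryAddWorker brings both counts to zero after `parallelism` successful additions.

```java
public class CtlSketch {
    // Layout assumed to match the JDK 8 ForkJoinPool source.
    static final int  AC_SHIFT   = 48;
    static final long AC_UNIT    = 0x0001L << AC_SHIFT;
    static final long AC_MASK    = 0xffffL << AC_SHIFT;
    static final int  TC_SHIFT   = 32;
    static final long TC_UNIT    = 0x0001L << TC_SHIFT;
    static final long TC_MASK    = 0xffffL << TC_SHIFT;
    static final long ADD_WORKER = 0x0001L << (TC_SHIFT + 15); // sign bit of the TC field

    /** Initial ctl: both counts start at -parallelism ("this many workers missing"). */
    static long initialCtl(int parallelism) {
        long np = (long) (-parallelism);
        return ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

    /** The value tryAddWorker CASes in: one more active and one more total worker. */
    static long addWorker(long c) {
        return (AC_MASK & (c + AC_UNIT)) | (TC_MASK & (c + TC_UNIT));
    }

    static int activeOffset(long c) { return (int) (c >> AC_SHIFT); }     // signed AC field
    static int totalOffset(long c)  { return (short) (c >>> TC_SHIFT); }  // signed TC field

    public static void main(String[] args) {
        long c = initialCtl(4);
        // Same conditions signalWork/tryAddWorker check: too few active, too few total, no idle worker.
        while (c < 0L && (c & ADD_WORKER) != 0L && (int) c == 0) {
            c = addWorker(c);
            System.out.printf("AC=%d TC=%d ctl<0=%b%n", activeOffset(c), totalOffset(c), c < 0L);
        }
    }
}
```

Storing the counts as offsets from the target parallelism lets the pool answer "do we need another worker?" with a single sign test on a value it can update atomically with one CAS.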

The createWorker method

private boolean createWorker() {
    ForkJoinWorkerThreadFactory fac = factory;
    Throwable ex = null;
    ForkJoinWorkerThread wt = null;
    try {
        // the same pattern as ThreadPoolExecutor:
        // if the thread is created, start it and return; otherwise roll back with deregisterWorker
        if (fac != null && (wt = fac.newThread(this)) != null) {
            wt.start();
            return true;
        }
    } catch (Throwable rex) {
        ex = rex;
    }
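    // deregisterWorker is the counterpart of registerWorker, which is called inside fac.newThread
    // rollback: decrements ctl, removes the worker's queue, and, if there are too few workers, calls tryAddWorker again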
    deregisterWorker(wt, ex);
    return false;
}

What does ForkJoinWorkerThreadFactory.newThread do? The default factory's implementation is shown below.
ForkJoinWorkerThreadFactory.newThread method

public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
    return new ForkJoinWorkerThread(pool);
}
// registers this thread's own workQueue in the pool's WorkQueue[] array
protected ForkJoinWorkerThread(ForkJoinPool pool) {
    // Use a placeholder until a useful name can be set in registerWorker
    super("aForkJoinWorkerThread");
    this.pool = pool;
    this.workQueue = pool.registerWorker(this);
}

final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
    UncaughtExceptionHandler handler;
    wt.setDaemon(true);                           // configure thread
    if ((handler = ueh) != null)
        wt.setUncaughtExceptionHandler(handler);
    // create a new WorkQueue; this is the worker thread's own queue
    WorkQueue w = new WorkQueue(this, wt);
    int i = 0;                                    // assign a pool index
    int mode = config & MODE_MASK;
    int rs = lockRunState();
    try {
        WorkQueue[] ws; int n;                    // skip if no array
        if ((ws = workQueues) != null && (n = ws.length) > 0) {
            int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
            int m = n - 1;
            // pick an odd index
            i = ((s << 1) | 1) & m;               // odd-numbered indices
            if (ws[i] != null) {                  // collision
                int probes = 0;                   // step by approx half n
                int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                while (ws[i = (i + step) & m] != null) {
                    if (++probes >= n) {
                        workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                        m = n - 1;
                        probes = 0;
                    }
                }
            }
            w.hint = s;                           // use as random seed
            w.config = i | mode;
            w.scanState = i;                      // publication fence
            // store the worker's workQueue at an odd index of the pool's array
            ws[i] = w;
        }
    } finally {
        unlockRunState(rs, rs & ~RSLOCK);
    }
    wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
    return w;
}
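registerWorker's index assignment can be replayed outside the pool. The sketch below keeps the same arithmetic. `SEED_INCREMENT = 0x9e3779b9` and `EVENMASK = 0xfffe` are copied from the JDK 8 source, which is an assumption for other versions. Unlike the original, the sketch stores plain objects in an `Object[]` and has no locking, and the class name `RegisterSketch` is hypothetical. The starting index is odd and the probe step is always even, so every worker queue stays on an odd slot. When a full round of probes finds no free slot, the table doubles.

```java
import java.util.Arrays;

public class RegisterSketch {
    // Assumed to match the JDK 8 ForkJoinPool constants.
    static final int SEED_INCREMENT = 0x9e3779b9;
    static final int EVENMASK = 0xfffe;

    private Object[] ws;
    private int indexSeed;

    RegisterSketch(int n) { ws = new Object[n]; }   // n must be a power of two

    /** Places worker w in an odd slot using registerWorker's probing; returns the slot. */
    int register(Object w) {
        int n = ws.length;
        int s = indexSeed += SEED_INCREMENT;
        int m = n - 1;
        int i = ((s << 1) | 1) & m;                  // always odd
        if (ws[i] != null) {                         // collision: step by about n/2, always even
            int probes = 0;
            int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
            while (ws[i = (i + step) & m] != null) {
                if (++probes >= n) {                 // looks full: double the table and keep probing
                    ws = Arrays.copyOf(ws, n <<= 1);
                    m = n - 1;
                    probes = 0;
                }
            }
        }
        ws[i] = w;
        return i;
    }

    int capacity() { return ws.length; }

    public static void main(String[] args) {
        RegisterSketch table = new RegisterSketch(8);
        for (int k = 0; k < 6; k++) {
            int idx = table.register("worker-" + k);
            System.out.println("worker-" + k + " -> slot " + idx + " (table size " + table.capacity() + ")");
        }
    }
}
```

Keeping workers on odd slots and external submitters on even slots means the two kinds of queues never compete for the same index.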

As we saw in createWorker, once the thread has been created it is started, so, as before, let's look at the run method of ForkJoinWorkerThread.
ForkJoinWorkerThread.run method

public void run() {
    if (workQueue.array == null) { // only run once
        Throwable exception = null;
        try {
            onStart();
            pool.runWorker(workQueue);
        } catch (Throwable ex) {
            exception = ex;
        } finally {
            try {
                onTermination(exception);
            } catch (Throwable ex) {
                if (exception == null)
                    exception = ex;
            } finally {
                pool.deregisterWorker(this, exception);
            }
        }
    }
}
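run() brackets runWorker with two overridable hooks, onStart and onTermination. Both are real protected methods of ForkJoinWorkerThread. A pool can be given a custom factory through the `ForkJoinPool(int, ForkJoinWorkerThreadFactory, UncaughtExceptionHandler, boolean)` constructor. The sketch below counts how often each hook fires; the class names `HookedWorkers` and `CountingWorker` are hypothetical. It is only meant to make the lifecycle above observable and is not something the article's code relies on.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class HookedWorkers {
    static final AtomicInteger started = new AtomicInteger();
    static final AtomicInteger terminated = new AtomicInteger();

    /** Worker thread whose hooks are called from ForkJoinWorkerThread.run(). */
    static final class CountingWorker extends ForkJoinWorkerThread {
        CountingWorker(ForkJoinPool pool) {
            super(pool);                    // the superclass constructor registers the worker's queue
        }

        @Override
        protected void onStart() {
            super.onStart();
            started.incrementAndGet();      // runs before pool.runWorker(workQueue)
        }

        @Override
        protected void onTermination(Throwable exception) {
            terminated.incrementAndGet();   // runs before pool.deregisterWorker(...)
            super.onTermination(exception);
        }
    }

    /** Runs one task in a pool built from CountingWorker threads and returns its result. */
    static int runOnce() {
        ForkJoinPool pool = new ForkJoinPool(
                2,                   // parallelism
                CountingWorker::new, // ForkJoinWorkerThreadFactory
                null,                // no UncaughtExceptionHandler
                false);              // asyncMode = false: LIFO local queues (the default)
        try {
            return pool.submit(() -> 6 * 7).join();
        } finally {
            pool.shutdown();
            try {
                pool.awaitTermination(10, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    public static void main(String[] args) {
        int answer = runOnce();
        System.out.println("answer=" + answer
                + " started=" + started.get() + " terminated=" + terminated.get());
    }
}
```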

run in turn calls ForkJoinPool's runWorker method

final void runWorker(WorkQueue w) {
    // allocate the task array
    w.growArray();                   // allocate queue
    int seed = w.hint;               // initially holds randomization hint
    int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
    for (ForkJoinTask<?> t;;) {
        // scan for a top-level task to steal, starting from a random queue
        if ((t = scan(w, r)) != null)
            w.runTask(t); // run the task
        else if (!awaitWork(w, r)) // nothing could be stolen: wait for work
            break;
        r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
    }
}
private ForkJoinTask<?> scan(WorkQueue w, int r) {
    WorkQueue[] ws; int m;
    // scan only if the workQueues array is non-empty
    if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
        int ss = w.scanState;                     // initially non-negative
        for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
            WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
            int b, n; long c;
            if ((q = ws[k]) != null) { // fetch a WorkQueue
                if ((n = (b = q.base) - q.top) < 0 &&
                    (a = q.array) != null) {      // non-empty
                    long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                    if ((t = ((ForkJoinTask<?>) // read the task at base
                              U.getObjectVolatile(a, i))) != null &&
                        q.base == b) {
                        if (ss >= 0) {
                            if (U.compareAndSwapObject(a, i, t, null)) {
                                q.base = b + 1; // advance base
                                if (n < -1)       // signal others
                                    signalWork(ws, q); // wake an idle worker or create one to help with the remaining tasks
                                return t;
                            }
                        }
                        else if (oldSum == 0 &&   // try to activate
                                 w.scanState < 0)
                            tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                    }
                    if (ss < 0)                   // refresh
                        ss = w.scanState;
                    // could not take a task here (contention or inactive): move to a random position and rescan
                    r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                    origin = k = r & m;           // move and rescan
                    oldSum = checkSum = 0;
                    continue;
                }
                checkSum += b;
            }
            // advance k to the next WorkQueue and keep looking
            if ((k = (k + 1) & m) == origin) {    // continue until stable
                // reaching here means every workQueue was scanned without finding a task
                if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                    oldSum == (oldSum = checkSum)) {
                    if (ss < 0 || w.qlock < 0)    // already inactive
                        break;
                    // try to inactivate the current WorkQueue
                    int ns = ss | INACTIVE;       // try to inactivate
                    long nc = ((SP_MASK & ns) |
                               (UC_MASK & ((c = ctl) - AC_UNIT)));
                    w.stackPred = (int)c;         // hold prev stack top
                    U.putInt(w, QSCANSTATE, ns);
                    if (U.compareAndSwapLong(this, CTL, c, nc))
                        ss = ns;
                    else
                        w.scanState = ss;         // back out
                }
                checkSum = 0;
            }
        }
    }
    return null;
}
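Strip out the CAS, the checkSum stability test and the inactivation logic, and scan comes down to three steps. It starts at a random index, walks the table with `k = (k + 1) & m`, and takes the oldest task (the base end) from the first non-empty queue. The single-threaded sketch below uses `java.util.ArrayDeque` in place of WorkQueue: the last element plays the role of top and the first element the role of base. It is a simplification under those assumptions, not a thread-safe implementation, and the class name `ScanSketch` is hypothetical. `nextSeed` repeats the xorshift step runWorker applies after each round.

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class ScanSketch {
    /**
     * Single-threaded simplification of scan(): start at origin = r & m, visit each
     * slot once via k = (k + 1) & m, and take the base-end (oldest) task from the
     * first non-empty queue. Returns null if a full sweep finds nothing.
     */
    static <T> T scan(Deque<T>[] queues, int r) {
        int m = queues.length - 1;            // table length assumed to be a power of two
        int origin = r & m;
        int k = origin;
        do {
            Deque<T> q = queues[k];
            if (q != null && !q.isEmpty())
                return q.pollFirst();         // the first element plays the role of base
            k = (k + 1) & m;
        } while (k != origin);
        return null;                          // the real scan() would now try to inactivate
    }

    /** The xorshift step runWorker applies to r after every round. */
    static int nextSeed(int r) {
        r ^= r << 13; r ^= r >>> 17; r ^= r << 5;
        return r;
    }

    @SuppressWarnings("unchecked")
    static Deque<String>[] sampleQueues() {
        Deque<String>[] queues = new Deque[4];
        queues[1] = new ArrayDeque<>();
        queues[1].addLast("x");               // pushes go to the last element (top)
        queues[1].addLast("y");
        queues[3] = new ArrayDeque<>();
        queues[3].addLast("z");
        return queues;
    }

    public static void main(String[] args) {
        Deque<String>[] queues = sampleQueues();
        int r = 1;
        for (String t; (t = scan(queues, r)) != null; r = nextSeed(r))
            System.out.println("origin " + (r & 3) + " -> stole " + t);
    }
}
```

Stealing from the base end takes the oldest task, which in a divide-and-conquer workload is usually the largest piece of remaining work. The owner, meanwhile, keeps working at the top.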

Once scan finds a task, the worker calls runTask on its own WorkQueue (w.runTask(t)):

final void runTask(ForkJoinTask<?> task) {
    if (task != null) {
        scanState &= ~SCANNING; // mark as busy
        // call the task's doExec method
        (currentSteal = task).doExec();
        U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
        execLocalTasks();
        ForkJoinWorkerThread thread = owner;
        if (++nsteals < 0)      // collect on overflow
            transferStealCount(pool);
        scanState |= SCANNING;
        if (thread != null)
            thread.afterTopLevelExec();
    }
}
final int doExec() {
    int s; boolean completed;
    if ((s = status) >= 0) {
        try {
            // call exec and store its return value in completed
            completed = exec();
        } catch (Throwable rex) {
            return setExceptionalCompletion(rex);
        }
        if (completed)
            s = setCompletion(NORMAL);
    }
    return s;
}

At this point we are finally close to the compute method that our test example overrides. Here is RecursiveTask, the class the example extends:

protected final boolean exec() {
    result = compute();
    return true;
}
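RecursiveTask adds very little on top of ForkJoinTask. It provides a result field, getRawResult/setRawResult, and an exec that stores compute()'s value and returns true, which doExec then turns into setCompletion(NORMAL). The hypothetical `MiniRecursiveTask` below rebuilds that from ForkJoinTask's three abstract methods to show the chain doExec -> exec -> compute. The `Sum` task halves the range the same way as the article's SumTask, but its leaf size of 1000 elements is an arbitrary choice.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;

public class MiniRecursiveTaskDemo {
    /** Hypothetical re-creation of what RecursiveTask adds on top of ForkJoinTask. */
    static abstract class MiniRecursiveTask<V> extends ForkJoinTask<V> {
        V result;

        protected abstract V compute();

        @Override public final V getRawResult() { return result; }       // what join() returns
        @Override protected final void setRawResult(V value) { result = value; }
        @Override protected final boolean exec() {                         // called by doExec()
            result = compute();
            return true;                                                   // true: doExec marks it NORMAL
        }
    }

    /** Sums numbers[from..to] inclusive, splitting in half like the article's SumTask. */
    static final class Sum extends MiniRecursiveTask<Long> {
        final int[] numbers;
        final int from, to;

        Sum(int[] numbers, int from, int to) { this.numbers = numbers; this.from = from; this.to = to; }

        @Override protected Long compute() {
            if (to - from <= 1000) {
                long total = 0;
                for (int i = from; i <= to; i++) total += numbers[i];
                return total;
            }
            int middle = (from + to) >>> 1;
            Sum left = new Sum(numbers, from, middle);
            Sum right = new Sum(numbers, middle + 1, to);
            left.fork();
            right.fork();
            return right.join() + left.join();
        }
    }

    /** Sums 0..n with a fresh pool. */
    static long sum(int n) {
        int[] numbers = new int[n + 1];
        for (int i = 0; i <= n; i++) numbers[i] = i;
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new Sum(numbers, 0, n));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(sum(1_000_000));
    }
}
```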

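Summary
So far we have seen the following. A task submitted to the pool is placed in a queue at an even index of the workQueues array. A new worker thread is then created, which initializes its own workQueue and places it at an odd index of the same array.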

The fork method

public final ForkJoinTask<V> fork() {
    Thread t;
    // forked from a ForkJoinWorkerThread: push onto that worker's own workQueue
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
        ((ForkJoinWorkerThread)t).workQueue.push(this);
    else
        ForkJoinPool.common.externalPush(this); // otherwise push to the common pool
    return this;
}
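fork() chooses its target queue purely from the type of the calling thread. The sketch below makes that visible with `ForkJoinTask.getPool()`, which returns the pool of the current worker thread, or null outside one. A task forked from main runs in `ForkJoinPool.commonPool()`, while a task forked inside a custom pool stays in that pool. The class names are hypothetical. The main thread waits on a CountDownLatch rather than calling join(), because joining from outside a pool can let the caller run the task itself.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.atomic.AtomicReference;

public class ForkTargetDemo {
    /** Records the pool it ran in (ForkJoinTask.getPool() is null outside worker threads). */
    static final class Probe extends RecursiveAction {
        final AtomicReference<ForkJoinPool> seen = new AtomicReference<>();
        final CountDownLatch done = new CountDownLatch(1);

        @Override protected void compute() {
            seen.set(ForkJoinTask.getPool());
            done.countDown();
        }
    }

    /** Forks a Probe from inside whatever worker thread runs this task. */
    static final class Parent extends RecursiveAction {
        final Probe child = new Probe();
        final CountDownLatch done = new CountDownLatch(1);

        @Override protected void compute() {
            child.fork();      // worker thread: pushed onto this worker's own WorkQueue
            child.join();
            done.countDown();
        }
    }

    /** Returns {forkFromMainRanInCommonPool, forkInsideCustomPoolStayedThere}. */
    static boolean[] check() {
        try {
            Probe external = new Probe();
            external.fork();              // not a worker thread: goes to ForkJoinPool.common
            external.done.await();        // wait on the latch, not join(), so main never runs it
            boolean first = external.seen.get() == ForkJoinPool.commonPool();

            ForkJoinPool pool = new ForkJoinPool(2);
            Parent parent = new Parent();
            pool.execute(parent);
            parent.done.await();
            boolean second = parent.child.seen.get() == pool;
            pool.shutdown();
            return new boolean[] {first, second};
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        boolean[] r = check();
        System.out.println("fork() from main -> common pool: " + r[0]);
        System.out.println("fork() inside custom pool -> same pool: " + r[1]);
    }
}
```

This is also why calling fork() outside any pool quietly uses the shared common pool, even when the program created its own ForkJoinPool.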

The push method

final void push(ForkJoinTask<?> task) {
    ForkJoinTask<?>[] a; ForkJoinPool p;
    int b = base, s = top, n;
    if ((a = array) != null) {    // ignore if queue removed
        int m = a.length - 1;     // fenced write for task visibility
        U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task); // enqueue the task
        U.putOrderedInt(this, QTOP, s + 1);
        if ((n = s - b) <= 1) {
            if ((p = pool) != null)
                p.signalWork(p.workQueues, this);
        }
        else if (n >= m) // the array is full: grow it
            growArray();
    }
}
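push, the pop performed by tryUnpush and the base-side take in scan all address the same circular array through `(length - 1) & index`. The `WorkDeque` below is a single-threaded, simplified model of that layout. It has no CAS and no ordered writes, so it is not safe for concurrent use. Its initial capacity of 4 is an arbitrary choice. `growArray` copies the live slots from base to top into an array twice the size, which is what the JDK version does under its own locking. The class names are hypothetical.

```java
public class WorkDequeSketch {
    /**
     * Single-threaded model of WorkQueue's array: the owner pushes and pops at top
     * (LIFO), thieves take from base (FIFO), and a slot is addressed as
     * (length - 1) & index, so the length must stay a power of two.
     * No CAS and no ordered writes: NOT safe for concurrent use.
     */
    static final class WorkDeque<T> {
        private Object[] array = new Object[4];   // arbitrary small initial capacity
        private int base;                         // next slot a thief takes from
        private int top;                          // next free slot for the owner

        void push(T task) {
            array[(array.length - 1) & top] = task;
            top++;
            if (top - base >= array.length)       // full after this push: grow now
                growArray();
        }

        @SuppressWarnings("unchecked")
        T pop() {                                 // owner side (top), LIFO
            if (top == base) return null;
            int i = (array.length - 1) & --top;
            T t = (T) array[i];
            array[i] = null;
            return t;
        }

        @SuppressWarnings("unchecked")
        T poll() {                                // thief side (base), FIFO, like scan()
            if (top == base) return null;
            int i = (array.length - 1) & base++;
            T t = (T) array[i];
            array[i] = null;
            return t;
        }

        int size() { return top - base; }

        int capacity() { return array.length; }

        private void growArray() {                // double, re-slotting live tasks base..top-1
            Object[] old = array;
            Object[] a = new Object[old.length << 1];
            int oldMask = old.length - 1;
            int newMask = a.length - 1;
            for (int i = base; i != top; i++)
                a[newMask & i] = old[oldMask & i];
            array = a;
        }
    }

    public static void main(String[] args) {
        WorkDeque<Integer> q = new WorkDeque<>();
        for (int i = 1; i <= 10; i++) q.push(i);
        System.out.println("owner pops " + q.pop() + ", thief steals " + q.poll()
                + ", size " + q.size() + ", capacity " + q.capacity());
    }
}
```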

Once compute calls fork on the subtasks, they are enqueued, and taskRight.join() then waits for its subtask to finish. Let's look at the logic of join.

// waits for the task to complete and returns its result
public final V join() {
    int s;
    if ((s = doJoin() & DONE_MASK) != NORMAL)
        reportException(s);
    return getRawResult();
}

private int doJoin() {
    int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
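    // on a worker thread: if this task is on top of the worker's queue, tryUnpush pops it and doExec runs it
    // directly (i.e. calls the subtask's compute); otherwise go to awaitJoin. Other threads use externalAwaitDone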
    return (s = status) < 0 ? s :
        ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (w = (wt = (ForkJoinWorkerThread)t).workQueue).
        tryUnpush(this) && (s = doExec()) < 0 ? s :
        wt.pool.awaitJoin(w, this, 0L) :
        externalAwaitDone();
}

The awaitJoin method

final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
    int s = 0;
    if (task != null && w != null) {
        ForkJoinTask<?> prevJoin = w.currentJoin;
        U.putOrderedObject(w, QCURRENTJOIN, task);
        CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
            (CountedCompleter<?>)task : null;
        for (;;) {
            if ((s = task.status) < 0)
                break;
            // CountedCompleter tasks are helped along with helpComplete
            if (cc != null)
                helpComplete(w, cc, 0);
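            // key step: if the local queue is non-empty, look for the task in it and run it via tryRemoveAndExec
            else if (w.base == w.top || w.tryRemoveAndExec(task))
                helpStealer(w, task); // queue empty, or the task is not here and not yet done (stolen): help the thief by running tasks from its queue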
            if ((s = task.status) < 0)
                break;
            long ms, ns;
            if (deadline == 0L)
                ms = 0L;
            else if ((ns = deadline - System.nanoTime()) <= 0L)
                break;
            else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                ms = 1L;
            // try to wake an idle worker or create a compensating one before blocking
            if (tryCompensate(w)) {
                // block this thread
                task.internalWait(ms);
                U.getAndAddLong(this, CTL, AC_UNIT);
            }
        }
        U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
    }
    return s;
}

The tryRemoveAndExec method

final boolean tryRemoveAndExec(ForkJoinTask<?> task) {
    ForkJoinTask<?>[] a; int m, s, b, n;
    if ((a = array) != null && (m = a.length - 1) >= 0 &&
        task != null) {
        while ((n = (s = top) - (b = base)) > 0) {
            // walk the queue from top to base; if the target task is found, run it with doExec
            for (ForkJoinTask<?> t;;) {      // traverse from s to b
                long j = ((--s & m) << ASHIFT) + ABASE;
                if ((t = (ForkJoinTask<?>)U.getObject(a, j)) == null)
                    return s + 1 == top;     // shorter than expected
                else if (t == task) {
                    boolean removed = false;
                    if (s + 1 == top) {      // pop
                        if (U.compareAndSwapObject(a, j, task, null)) {
                            U.putOrderedInt(this, QTOP, s);
                            removed = true;
                        }
                    }
                    else if (base == b)      // replace with proxy
                        removed = U.compareAndSwapObject(
                            a, j, task, new EmptyTask());
                    if (removed)
                        task.doExec();
                    break;
                }
                else if (t.status < 0 && s + 1 == top) {
                    if (U.compareAndSwapObject(a, j, t, null))
                        U.putOrderedInt(this, QTOP, s);
                    break;                  // was cancelled
                }
                if (--n == 0)
                    return false;
            }
            if (task.status < 0)
                return false;
        }
    }
    return true;
}
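The join-order puzzle comes down to which branch doJoin takes. The sketch below models only the bookkeeping on the owner's queue, using a List whose last element is top. tryUnpush succeeds when the joined task is on top. Otherwise tryRemoveAndExec walks from top toward base and, since the task is not on top, leaves a placeholder standing in for EmptyTask. The class name `JoinPathSketch` is hypothetical. Nothing is executed and no time is measured, so the sketch illustrates the path taken rather than proving the timing difference the author observed.

```java
import java.util.ArrayList;
import java.util.List;

public class JoinPathSketch {
    /** Stands in for the EmptyTask proxy that tryRemoveAndExec leaves behind. */
    static final Object EMPTY_TASK = new Object();

    /**
     * Models doJoin on the owner's own queue (last list element = top).
     * Logs which path was taken; nothing is actually executed.
     */
    static void join(List<Object> queue, Object task, List<String> log) {
        int top = queue.size() - 1;
        if (top >= 0 && queue.get(top) == task) {   // tryUnpush: task is on top, pop it
            queue.remove(top);
            log.add("tryUnpush");
            return;
        }
        int inspected = 0;
        for (int s = top; s >= 0; s--) {            // tryRemoveAndExec: walk from top toward base
            inspected++;
            if (queue.get(s) == task) {
                queue.set(s, EMPTY_TASK);           // not on top: replace it with a proxy
                log.add("tryRemoveAndExec, " + inspected + " slots inspected");
                return;
            }
        }
        log.add("not in local queue (stolen)");
    }

    /** One level of SumTask: fork left, fork right, then join in the given order. */
    static List<String> simulate(boolean rightFirst) {
        Object left = new Object();
        Object right = new Object();
        List<Object> queue = new ArrayList<>();
        queue.add(left);                            // taskLeft.fork()
        queue.add(right);                           // taskRight.fork()
        List<String> log = new ArrayList<>();
        if (rightFirst) {
            join(queue, right, log);
            join(queue, left, log);
        } else {
            join(queue, left, log);
            join(queue, right, log);
        }
        return log;
    }

    public static void main(String[] args) {
        System.out.println("right first: " + simulate(true));
        System.out.println("left first:  " + simulate(false));
    }
}
```

In the real pool the slow branch also means entering awaitJoin's loop and allocating an EmptyTask at every level of the recursion. That overhead is a plausible source of the difference, but the sketch does not measure it.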

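Now the whole flow fits together. Running SumTask.compute in the example creates subtasks; subtask.fork() enqueues them, and subtask.join() ends up executing the subtask's compute method.
With join analyzed, we can now answer why taskRight.join() + taskLeft.join() is faster.
taskLeft.fork() pushes taskLeft, and taskRight.fork() then pushes taskRight on top of it. When taskRight.join() runs next, taskRight is the top task, so tryUnpush pops it and it executes immediately without traversing the queue. Afterwards taskLeft is back on top (unless it was stolen) and takes the same fast path. If taskLeft is joined first, tryUnpush fails because taskLeft is not on top, so the thread enters awaitJoin. There, tryRemoveAndExec has to walk the queue from the top and replace taskLeft with an EmptyTask placeholder before running it.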
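A common way to sidestep the question altogether is to fork only one half and compute the other half directly in the current thread. The Fibonacci example in the RecursiveTask Javadoc uses this idiom. The forked half is then normally back on top of the queue by the time it is joined. The sketch below applies the idiom to the article's summation. The class name `ForkComputeJoin` and the leaf threshold of 1,000 elements are arbitrary assumptions.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.stream.IntStream;

public class ForkComputeJoin {
    static final int THRESHOLD = 1_000;     // arbitrary leaf size for this sketch

    /** Sums numbers[from..to] inclusive: fork one half, compute the other inline. */
    static final class SumTask extends RecursiveTask<Long> {
        private final int[] numbers;
        private final int from;
        private final int to;

        SumTask(int[] numbers, int from, int to) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= THRESHOLD) {
                long total = 0;
                for (int i = from; i <= to; i++) total += numbers[i];
                return total;
            }
            int middle = (from + to) >>> 1;
            SumTask left = new SumTask(numbers, from, middle);
            SumTask right = new SumTask(numbers, middle + 1, to);
            left.fork();                        // pushed onto this worker's queue
            long rightSum = right.compute();    // run inline: no push, no queue lookup
            return rightSum + left.join();      // left is normally on top again (unless stolen)
        }
    }

    static long sum(int[] numbers) {
        if (numbers.length == 0) return 0L;
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumTask(numbers, 0, numbers.length - 1));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        int[] numbers = IntStream.rangeClosed(0, 1_000_000).toArray();
        System.out.println(sum(numbers));
    }
}
```

Computing one half inline also saves a push, a pop and a task's worth of status bookkeeping per split. `ForkJoinTask.invokeAll(left, right)` is the library's ready-made alternative.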