ForkJoinPool使用以及原理解读

2,968 阅读10分钟

1 前言

Java7 提供了ForkJoinPool来支持将一个任务拆分成多个“小任务”并行计算,再把多个“小任务”的结果合并成总的计算结果。

ForkJoinPool 不是为了替代 ExecutorService,而是它的补充,在某些应用场景下性能比 ExecutorService 更好。利用分而治之的思想+工作窃取算法,实现的一种线程池;最适合的是计算密集型的任务,如果存在 I/O,线程间同步,sleep() 等会造成线程长时间阻塞的情况时,最好配合使用 ManagedBlocker。

下面会从下面几个方面来介绍

  1. 使用
  2. 数据结构
  3. 流程与逻辑

2 使用

使用需要下面步骤

  1. 定义任务

    • 普通任务:runnable接口,Callable接口等实现类
    • ForkJoinTask子类:CountedCompleter、RecursiveAction、RecursiveTask是其进一步实现封装的抽象类;用户选取上述类自行实现即可
  2. 提交任务

    • 普通任务提交,必须使用ForkJoinPool
    • ForkJoinTask类型任务,可以fork来处理
  3. 获取结果

    • ForkJoinTask任务句柄,join方法处理
    • 线程池invoke方法,提交并执行

普通任务这里就不给示例了,和ThreadPoolExecutor使用没有啥区别;下面举个 ForkJoinTask类型任务例子

2.1 定义任务

class Task(private val num : Int) : RecursiveTask<Long>() {
    override fun compute(): Long {
        if (num < 2) return 1L
        val t1 = Task(num - 1)
        val t2 = Task(num - 2)
        t1.fork()
        t2.fork()
        return t1.join() + t2.join()
    }
}

RecursiveTask是有计算结果的任务,RecursiveAction无计算结果的任务;CountedCompleter后面会单独介绍

2.2 任务提交、结果获取

    val task = Task(20)
    task.fork()
    print(task.join())

又或者线程池提交

    print(ForkJoinPool(10).submit(Task(20)).join())
    print(ForkJoinPool(10).invoke(Task(20)))
    print(ForkJoinPool.commonPool().invoke(Task(20)))

ForkJoinPool.commonPool()为通用的、已提供的ForkJoinPool实例;这里要注意join方法为阻塞方法;另外也要注意,fork方法虽然是提交任务,但是任务有可能被窃取执行,所以,join有可能立即获取结果;所以需要在合理的地方进行结果获取;也可获取提交任务句柄,在需要的地方进行获取值

使用是不是很简单,但是我说上面任务返回结果

    return t1.join() + t2.join()

替换为

    return t2.join() + t1.join()

执行效率会略高一些,你会信吗?这个和join方法内的逻辑有关,如果任务最后一个加入,则可以优先执行,而不必等待

2.3 CountedCompleter任务

复杂且使用比较灵活;它可以通过内部逻辑把自己转化为RecursiveTask、RecursiveAction任务,也可以更灵活的使用,并且最大的不同就是其只有一个任务需要join操作且任务间并不阻塞线程池内部的调用,任务间的联系需要通过相应回调来触发,其通过完成回调方法合并其依赖的结果;内部增加了如下两个成员变量

    final CountedCompleter<?> completer;
    volatile int pending;

completer:依赖当前任务的节点;其像链表,但又不是,说是树可能更合适;最开始的那个任务,是树根节点,其依赖的为其孩子节点

pending: 当前节点依赖的节点个数,也可以说其孩子节点的个数;类中提供了一些列的方法操作,不介绍了;其内部方法调用时,都是先于0比较,然后,才会减少1,所以内部方法进行结束任务时,这个个数+1才是依赖的数目

一般情况下,我们不需直接对pending直接操作,可以使用其已经提供的一些方法进行操作,进而达到效果;方法有下面几个:

  • tryComplete:当前点为出发点,向依赖其节点进行循环处理,遇到以下情况会结束

    1. pending为0且依赖其的节点为空:pending为0时,回调onCompletion完成处理方法;若依赖其节点为空,则调用quietlyComplete方法设置执行状态为完成
    2. 处理当前节点pending值-1成功
  • propagateCompletion方法,和tryComplete方法相比,无onCompletion方法回调调用,也即对于每个中间任务无需关注

  • quietlyCompleteRoot : 依照指针域去寻找根依赖节点,并为其设置正常结束状态;比较暴力的结束任务状态,这种适合于找到某一个结果就停止

onCompletion回调方法

这个方法是仅仅通知当前任务所有依赖已经完成,用于任务合并操作,但却在此方法中仅仅知道最后一个完成的依赖任务;

为何CountedCompleter要设置正常结束状态,这时由于ForkJoinTask在执行方法的逻辑

    final int doExec() {
        int s; boolean completed;
        if ((s = status) >= 0) {
            try {
                completed = exec();
            } catch (Throwable rex) {
                return setExceptionalCompletion(rex);
            }
            if (completed)
                s = setCompletion(NORMAL);
        }
        return s;
    }

也即是,现有ForkJoinTask的子类exec方法,均是返回true,而只有CountedCompleter返回false,所以其需要设置正常结束状态,任务才会被结算成执行完毕,在任务fork等调用时,才会结束阻塞;如果你只是往里面添加一个任务这个则不处理也没有关系

类似RecursiveAction的效果

class Task(private val num : Int,private val end : Int, completer: Task? = null) : CountedCompleter<Void>(completer) {
    override fun compute() {
        if (end == num) {
            if (end % 2 == 0) println("odd $end")
            propagateCompletion()
            return
        }
        addToPendingCount(1)
        val middle = (num + end) / 2
        Task(num, middle, this).fork()
        Task(middle + 1, end,this).fork()
    }
}

类似RecursiveTask的效果

class Task(val num : Int,val end : Int, completer: Task? = null) : CountedCompleter<Int>(completer) {
    @Volatile public var mResult = 0
    private var t1 : Task? = null
    private var t2 : Task? = null
    override fun compute() {
        if (end == num) {
            mResult = end
            tryComplete()
            return
        }
        addToPendingCount(1)
        val middle = (num + end) / 2
        t1 = Task(num, middle, this).fork() as Task
        t2 = Task(middle + 1, end,this).fork() as Task
    }

    override fun onCompletion(caller: CountedCompleter<*>?) {
        if (this != caller && caller is Task) {
            mResult = (t1?.mResult ?: 0) + (t2?.mResult ?: 0)
        }
    }

    override fun getRawResult(): Int {
        return mResult
    }

    override fun setRawResult(t: Int?) {
        mResult = t ?: 0
    }
}

如果不通过根任务的join等方法获取结果,而是其它数据交流的办法(Rxjava 中发射、LiveData等),则可以不重写get/setRawResult方法

某个特殊结果寻找

class Task(val num : Int,val end : Int, completer: Task? = null) : CountedCompleter<Int>(completer) {
    @Volatile public var mResult = 0
    override fun compute() {
        if (end % 7 == 0 && end % 5 == 0) {
            (root as Task).mResult = end
            quietlyCompleteRoot()
            return
        } else if (num == end) {
            return
        }
        addToPendingCount(1)
        val middle = (num + end) / 2
        Task(num, middle, this).fork()
        Task(middle + 1, end,this).fork()
    }

    override fun getRawResult(): Int {
        return mResult
    }

    override fun setRawResult(t: Int?) {
        mResult = t ?: 0
    }
}

可能还有其它场景,但是这些场景的处理都是依据pending值和其引用来确定是否设置结束状态;

  • 原子操作设置值:addToPendingCount、compareAndSetPendingCount等方法
  • 利用设置状态方法来处理:propagateCompletion、tryComplete、quietlyCompleteRoot等

3 原理实现

ForkJoinPool线程池,其执行任务的线程对象是ForkJoinWorkerThread子类,任务均被包装为ForkJoinTask的子类

3.1 ForkJoinWorkerThread类

Thread子类,其中主要内容有:线程队列创建、销毁、执行

3.1.1 线程队列创建

在构造器中通过ForkJoinPool.registerWorker方法为当前线程关联队列,队列位置为线程池队列数组的奇数位置

3.1.2 线程的销毁

通过ForkJoinPool.deregisterWorker方法进行销毁

3.1.4 线程的运行

run方法内为其主要逻辑,不贴代码了;需要在其线程队列建立后,持有数据还未申请空间之前进行线程执行,否则不做任何处理

回调方法onStart,表示线程开始执行;通过ForkJoinPool.runWorker方法来执行任务;onTermination回调方法接收异常处理;

3.2 ForkJoinTask类

抽象类,实现了Future、Serializable接口;其主要内容:任务异常收集、fork-join执行流程(join也可以是invoke、get等操作,但这里就依据join来讲解)

task有以下几种状态

    volatile int status;
    static final int DONE_MASK   = 0xf0000000;
    static final int NORMAL      = 0xf0000000;
    static final int CANCELLED   = 0xc0000000;
    static final int EXCEPTIONAL = 0x80000000;
    static final int SIGNAL      = 0x00010000;
    static final int SMASK       = 0x0000ffff;
  • NORMAL:结束状态,正常结束,负数
  • CANCELLED:结束状态,用户取消,负数
  • EXCEPTIONAL:结束状态,执行异常,负数
  • SIGNAL:等待通知执行状态,正数
  • 0 : 起始状态

3.2.1 异常收集

异常数据收集,是根据弱引用机制来处理;弱引用任务节点结构如下:

static final class ExceptionNode extends WeakReference<ForkJoinTask<?>> {
        final Throwable ex;
        ExceptionNode next;
        final long thrower; 
        final int hashCode; 
        ExceptionNode(ForkJoinTask<?> task, Throwable ex, ExceptionNode next,
                      ReferenceQueue<Object> exceptionTableRefQueue) {
            super(task, exceptionTableRefQueue);
            this.ex = ex; // 原始异常
            this.next = next; // 相同hash的节点指针域
            this.thrower = Thread.currentThread().getId(); // 线程标识
            this.hashCode = System.identityHashCode(task); // 与对象地址相对应的hash
        }
    }

弱引用节点相关数据结构

    private static final ExceptionNode[] exceptionTable; // 异常数据
    private static final ReentrantLock exceptionTableLock; // 异常节点锁
    private static final ReferenceQueue<Object> exceptionTableRefQueue; // 弱引用回收队列

采用的数组存储,并利用hash进行映射,单链表进行冲突解决;并在需要处理异常时,实时去除已经销毁的task节点异常;常用操作如下:

  • 记录异常:recordExceptionalCompletion方法,在任务未完成的情况才会记录
  • 清除当前节点异常:clearExceptionalCompletion方法
  • 获取异常:getThrowableException,非当前线程异常,需要进行包装转换
  • 清理无效task相关联异常:expungeStaleExceptions静态方法,清除掉回收队列中task所有相关异常节点

3.2.2 fork-join逻辑

fork方法用于向队列中保存任务;偶数任务队列中未依赖于线程,奇数队列为线程私有

   public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }
  1. 当前在ForkJoinWorkerThread线程中执行,则调用workQueue.push方法存入队列
  2. 放入线程池中队列数组中偶数位置的队列中

join方法用于阻塞获取结果

    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }
    
    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

同样需要根据线程类型判断

  1. 状态小于0,也即任务已结束,则直接返回,如果是异常则会抛出异常
  2. 未执行时,不是ForkJoinWorkerThread线程内执行,以当前任务实例为锁对象,进行等待(更具体的逻辑在externalAwaitDone方法内分析)
  3. 未执行时,ForkJoinWorkerThread线程内执行;如果任务为当前线程队列的顶部(也就是最后一个提交的)且执行后处于结束状态,则返回
  4. 线程池内awaitJoin进行等待(其时可能存在窃取其它任务队列进行执行)

externalAwaitDone方法

首先尝试执行,如果满足下面条件,则会执行doExec方法(调用exec()方法进行具体执行)

  1. CountedCompleter任务类型,则common线程池方法externalHelpComplete返回true
  2. 其它任务类型,common线程池tryExternalUnpush方法返回true

如果未执行,则通过staus原子操作+synchronized锁,进行等待

3.2 ForkJoinPool类

这里主要有一些常量的意义、队列结构、执行流程、窃取线程思路;

3.2.1 状态成员变量

    volatile long ctl;                
    volatile int runState;
    final int config;

ct1,64位,分为4段,每相邻16位为一段

  • 高16位,正在处理任务的线程个数;初始化为并行数的负值(构造器中线程的并行线程数,一般来说为能创建的最大线程数)
  • 次高16位,线程总数,初始化为并行数的负值
  • 次低16位,线程状态,小于0时需要添加新的线程,或者说48位的位置为1时,需要添加线程
  • 低16位,空闲线程对应的任务队列在队列数组的索引位置

runState,有下面几种状态,默认态为0

    private static final int  STARTED    = 1;
    private static final int  STOP       = 1 << 1;
    private static final int  TERMINATED = 1 << 2;
    private static final int  SHUTDOWN   = 1 << 31;

config:低16位代表 并行度(parallelism),高16位:队列模式,默认是后进先出

3.2.3 线程队列

volatile WorkQueue[] workQueues

数组结构,分为线程队列和非线程队列,随机寻找位置进行创建与查找;达到WorkQueue均匀处理,以减少WorkQueue同步开销

        volatile int scanState;    // 负数:inactive, 非负数:active, 其中奇数代表scanning
        int stackPred;             // sp = (int)ctl, 前一个队列栈的标示信息,包含版本号、是否激活、以及队列索引
        int nsteals;               // 窃取的任务数
        int hint;                  // 一个随机数,用来帮助任务窃取,在 helpXXXX()的方法中会用到
        int config;                // 配置:二进制的低16位代表 在 queue[] 中的索引,高16位:mode可选FIFO_QUEUE(1 << 16)和LIFO_QUEUE(1 << 31),默认是LIFO_QUEUE
        volatile int qlock;        // 锁定标示位:1: locked, < 0: terminate; else 0
        volatile int base;         // index of next slot for poll
        int top;                   // index of next slot for push
        ForkJoinTask<?>[] array;   // 任务列表

WorkQueue中数据结构主体:任务数组、任务队列头部、尾部;以及线程操作同步标志,使用原子操作+volatile来实现,-1表示不允许操作了、0表示可以操作、1表示正常操作

因此其方法可以分为线程安全方法、非线程安全方法;线程安全方法用于窃取,非线程安全方法用于线程内任务执行

  • push方法:队列尾部加入数据,非线程安全
  • growArray方法:数组扩容,2被扩容,非线程安全
  • pop方法:从尾部取出数据,原子操作保证线程安全,但不保证成功
  • pollAt方法:从头部取出数据,原子操作保证线程安全,但不保证成功
  • poll: 从头部取出数据,原子操作+自旋,保证线程安全
  • nextLocalTask:根据策略,进行取出数据(根据congfig来进行处理),线程安全
  • peek:根据出队模式返回队头或者队尾元素,但不取出,非线程安全
  • tryUnpush:尝试判断是否为队尾任务,线程安全,但结果不一定准确
  • sharedPush:共享队列(偶数位置的WorkQueue实例),队尾增加数据方法,使用qlock原子操作来实现线程安全,但不保证结果准确,其中队列扩容通过growAndSharedPush方法处理并增加数据
  • trySharedUnpush:判断任务是否处于队尾,原子操作保证线程安全,不保证结果准确
  • cancelAll: 取消所有任务
  • localPopAndExec:从队尾开始执行任务,原子操作+自旋来保证线程安全,存在线程竞争时,则退出,不进行处理
  • localPollAndExec:从队头开始执行任务,原子操作+自旋来保证线程安全,存在线程竞争时,则退出,不进行处理
  • runTask:执行窃取任务,并依据出队某事调用localPopAndExec或者localPollAndExec来继续本线程队列任务处理
  • tryRemoveAndExec:自旋+原子操作,尽可能执行线程私有队列中的任务;非队尾数据,原子操作为EmptyTask
  • popCC:取出队尾的CountedCompleter任务,原子操作+自旋保证线程安全
  • pollAndExecCC:取出队头CountedCompleter任务,并执行,原子操作+自旋保证线程安全

3.2.4 调用流程

主要有下面三个流程提交任务流程、线程执行流程、获取结果流程

提交任务

从类的角度来看

  1. 线程池提交任务
  2. ForkJoinTask类的fork

从功能角度来看

  1. Fork线程内部提交任务
  2. 非Fork线程提交任务,第一个任务肯定是这种方式

外部提交任务

forkJoinThread外部提交任务.jpg

内部提交任务,直接调用线程私有WorkQueue对象,push方法加入队尾

线程执行

ForkJoinPool执行线程池.png

join获取任务结果

join结果获取.png

从上面三个流程能够大致知道处理的流程,但是偷取的具体的逻辑还是不清楚的;有下面方法需要仔细研读,掌握思想精髓

  1. scan方法:fork线程窃取任务,fork线程的第一个任务都是窃取而来
  2. awaitJoin方法:线程池内等待,不可被处理时,自己偷自己的任务
  3. CountedCompleter任务与其它任务处理的区别,CountedCompleter任务不会相互阻塞
  4. 锁等待机制:图中可能存在错误;闲置线程,才会线程暂停或者启用,任务的暂停等待则是Object的wait方法,且其执行结束后会notifyAll唤醒所有
  5. 位运算运用,以及各种状态之间的判断处理,以及这些对性能的一些追求

具体点方法分析,我也有部分点不是很明白,但是如果不写相关偷窃算法或者一些转移思想,有部分不清楚也是可以的

技术变化都很快,但基础技术、理论知识永远都是那些;作者希望在余后的生活中,对常用技术点进行基础知识分享;如果你觉得文章写的不错,请给与关注和点赞;如果文章存在错误,也请多多指教!