压榨CPU为乐——Future&ForkJoin原理分析

148 阅读14分钟

Future

在日常开发中,多多少少都会使用线程池去异步处理相关业务,比如像发送短信,邮件发送,消息推送等等。大多数场景下都是使用execute()方法去执行,但是有些时候需要进行远程调用,处理大量的数据或者其他异步执行并且需要获取异步执行的结果,那么这个时候就可以用到ExecutorService提供的submit()方法。

<T> Future<T> submit(Callable<T> task);

调用submit方法会返回一个Future,Future是一个接口,在接口中提供了5种方法,分别为获取执行结果方法get(),获取当前任务是否执行完成isDone(),判断当前任务取消状态isCancelled()等。

public interface Future<V> {

    /**
    * 取消任务 
    */
    boolean cancel(boolean mayInterruptIfRunning);

    /**
    * 判断当前任务是否被取消
    */
    boolean isCancelled();

      /**
      * 判断当前任务是否完成
      */
    boolean isDone();

   /**
   * 获取任务返回结果
   */
    V get() throws InterruptedException, ExecutionException;

   
   /**
   * 获取任务返回结果
   * 如果在指定时间范围内没获取到,则超时
   */
    V get(long timeout, TimeUnit unit)
        throws InterruptedException, ExecutionException, TimeoutException;
}

下面通过例子来进行原理以及源码分析,代码很简单创建一个线程池,通过线程池执行submit。其中一个线程内部执行缓慢,另一个线程内部什么都不做,然后查看执行效果。

public static void main(String[] args) throws ExecutionException, InterruptedException {
    ExecutorService executorService = Executors.newFixedThreadPool(5);

    Future<Integer> getSum = executorService.submit(() -> {
        TimeUnit.SECONDS.sleep(2);
        int sum = 0;
        for (int i = 0; i < 1000000; i++) {
            sum+=i;
        }
        return sum;
    });

    Future<Integer> future = executorService.submit(() -> { return 1; });
    
    TimeUnit.SECONDS.sleep(1);
    
    System.out.println("future====> "+future.isDone());
    System.out.println("future====> "+future.get());
    System.out.println("getSum====> "+getSum.isDone());
    System.out.println("getSum====> "+getSum.get());

    executorService.shutdown();
}

image.png

通过运行结果发现future任务已经完成了并且成功获取到了执行结果,但是getSum的任务状态是未完成,怎么也获取到任务执行的结果了呢?这是因为在调用get方法时,如果当前任务没有执行完成,那么就会阻塞线程,直到执行完成才会唤起线程继续执行。


 //初始化任务
 private static final int NEW          = 0;
 //任务已经完成,分为正常完成和异常完成,是一个中间状态
 private static final int COMPLETING   = 1;
 //正常完成,是最终状态
 private static final int NORMAL       = 2;
 //任务已经完成,但是异常完成,是最终状态
 private static final int EXCEPTIONAL  = 3;
 //任务还没开始执行就被取消了,是最终状态
 private static final int CANCELLED    = 4;
 //任务还没开始执行就被打断(这里的中断指的是Thread的interrupt方法),是中间状态
 private static final int INTERRUPTING = 5;
 //任务还没开始执行就被打断(比如操作系统kill信号),是最终状态
 private static final int INTERRUPTED  = 6;

//任务本身
private Callable<V> callable;
//任务结果
private Object outcome;  
//运行的线程 
private volatile Thread runner;
//等待节点,是一个单项列表结构
private volatile WaitNode waiters;


public V get() throws InterruptedException, ExecutionException {
    int s = state;
    if (s <= COMPLETING) // 如果当前任务为初始化状态,则等待任务执行完成
        s = awaitDone(false, 0L);
    return report(s); // 返回执行结果
}


/**
* 等待链表,当future任务尚未完成时,提交任务的线程将被阻塞并插入到列表中,直到run方法执行完成在唤醒
*/
static final class WaitNode {
    volatile Thread thread; //当前线程
    volatile WaitNode next; // 下一个节点
    WaitNode() { thread = Thread.currentThread(); }
}

/**
* 初始化节点,插入链表并阻塞线程
*/
private int awaitDone(boolean timed, long nanos)
    throws InterruptedException {
    final long deadline = timed ? System.nanoTime() + nanos : 0L;
    WaitNode q = null;
    boolean queued = false;
    for (;;) {
        // 线程是否中断,如果中断移除该节点
        if (Thread.interrupted()) {
            removeWaiter(q);
            throw new InterruptedException();
        }

        int s = state;
        if (s > COMPLETING) { // 如果当前节点状态是完成状态
            if (q != null)
                q.thread = null; //置空节点,等待GC
            return s;
        }
        else if (s == COMPLETING) // cannot time out yet
            Thread.yield();
        else if (q == null) // 初始化一个新的等待节点
            q = new WaitNode();
        else if (!queued) // 将等待节点插入列表
            queued = UNSAFE.compareAndSwapObject(this, waitersOffset,
                                                 q.next = waiters, q);
        else if (timed) { // 如果设置了超时时间
            nanos = deadline - System.nanoTime();
            if (nanos <= 0L) {
                removeWaiter(q);
                return state;
            }
            LockSupport.parkNanos(this, nanos);
        }
        else
            LockSupport.park(this); //阻塞当前线程
    }
}
/**
* 任务执行的主方法
*
*/
public void run() {
   ......
    try {
        Callable<V> c = callable;
        if (c != null && state == NEW) {
            V result;
            boolean ran;
            try {
                result = c.call();
                ran = true;
            } catch (Throwable ex) {
                result = null;
                ran = false;
                setException(ex);
            }
            if (ran) // 如果任务执行成功,则设置结果,并唤起线程
                set(result);
        }
    } finally {
      .......
    }
}


/*
* 更改state,设置任务执行结果,唤醒阻塞线程
*
*/
protected void set(V v) {
    if (UNSAFE.compareAndSwapInt(this, stateOffset, NEW, COMPLETING)) {
        outcome = v; //设置任务执行结果
        UNSAFE.putOrderedInt(this, stateOffset, NORMAL); // final state
        finishCompletion(); //唤醒阻塞线程
    }
}


/**
* 返回任务执行结果
*/
private V report(int s) throws ExecutionException {
    Object x = outcome;
    if (s == NORMAL) //如果任务正常执行完成
        return (V)x; // 返回结果
    if (s >= CANCELLED)
        throw new CancellationException();
    throw new ExecutionException((Throwable)x);
}

ForkJoin

ForkJoin是JDK7开始提供的用于并行执行任务的框架,其采用分而治之的思想,将一个大任务拆分成若干个小任务,且每个小任务之间不存在任何关联,最终将小任务的执行结果进行汇总,从而得到最终的一个结果。这种方式可以充分利用CPU多核的特点,提高程序的运算速度以及CPU利用率。

image.png

ForkJoinTask为我们提供了两个子类,分别为RecursiveTask,RecursiveAction和CountedCompleter,其中RecursiveTask适合有返回结果的任务,RecursiveAction适合没有返回结果的任务,CountedCompleter无返回值任务,完成任务后可以触发回调。可以结合自己的业务场景,选择合适的类型并且自行实现分治逻辑。需要注意的是ForkJoin仅适用与大数据量计算型业务,不适合IO频繁的业务,比如在你想在ForkJoin里面去操作数据库或者缓存之类的业务,那么就不推荐使用。举个简单的例子,计算十亿内数据的总和,分别用单线程和ForkJoin做对别,看哪个方法执行效率高。

class ArraySum extends RecursiveTask<Long> {

    public long threshold = 10000; // 每个子任务的粒度

    public long startIdx; // 子任务的开始值

    public long endIdx; //子任务结束值

    public ArraySum(long startIdx, long endIdx) {
        this.endIdx = endIdx;
        this.startIdx = startIdx;
    }

    @Override
    protected Long compute() {

        if ((endIdx - startIdx) <= threshold) { // 如果当前范围达不到可分配任务力度,自行计算
            long sum = 0;
            for (long i = startIdx; i <= endIdx; i++) {
                sum+=i;
            }
            return sum;
        } else {
            // 采用类似于递归的思想,取中值,并且进行任务拆分
            long mid = (endIdx + startIdx) / 2;
           
            ArraySum sumLeft = new ArraySum(startIdx, mid); // 以中值为目标 取左边。作为一个新任务

            ArraySum sumRight = new ArraySum(mid + 1, endIdx); // 以中值为目标 取右边。作为一个新任务

            sumLeft.fork(); // 执行任务,如果没有达到任务力度,会继续拆分

            sumRight.fork(); // 执行任务,如果没有达到任务力度,会继续拆分

            return sumLeft.join() + sumRight.join(); //计算任务执行结果
        }
    }
public static void main(String[] args) throws ExecutionException, InterruptedException {

    long min = 1;
    long max= 1000000000L;
    long sum=0;
    long start = System.currentTimeMillis();
    for (long i = min; i <= max; i++) {
        sum = sum + i;
    }
    System.out.println(sum + "   单线程用时:"+(System.currentTimeMillis()-start));


    ForkJoinPool forkJoinPool = new ForkJoinPool();
    start = System.currentTimeMillis();
    ForkJoinTask<Long> task = forkJoinPool.submit(new ArraySum(min, max));

    System.out.println(task.get()+ "   ForkJoin用时:"+(System.currentTimeMillis()-start));
}

image.png 通过运行结果可以看出来,ForkJoin的效率要比单线程的效率高出一倍。默认情况下,ForkJoin的工作线程数为运行机器的CPU核数Runtime.getRuntime().availableProcessors(),所以会将子任务划分到每个CPU上去执行,这样无论是CPU利用率还是程序运行效率都会高很多。

image.png

那ForkJoin内部到底是怎么运行的呢?任务拆分后存储在哪里呢?在ForkJoinPool内部维护了一个WorkQueue队列数组,每个线程都会绑定一个自己的WorkQueue,WorkQueue是一个双端队列,为什么要设计成双端队列呢,这是因为方便自身线程可以从尾部取元素的同时,其他空闲线程可以从头部实现工作元素窃取,减少竞争。拆分出来的任务将会存储到这个工作队列中,为了区分是ForkJoin内部线程还是外部线程,WorkQueue数组采用奇偶下标的方式来区分内部线程和外部线程。如果数组下标为奇数,则存储外部线程任务,如果是偶数则存储内部线程任务。当ForkJoin工作线程从队列中取出任务元素时,会判断是否达到了任务的拆分粒度,如果没有达到那么会将该任务重新拆分并如队,否则将会执行该任务。如果该工作队列元素为空,那么此时该线程就会挨个扫描任务队列,如果发现其中有队列元素不为空,那么就会从队列头部窃取(work-stealing)该任务到自己的工作队列并执行,直到所有任务全部完成。

工作窃取是指,当一个ForkJoin内部线程空闲时,会扫描其他工作队列是否存在未完成的任务。如果有发现某个工作队列有未完成的任务,那么就会从队列头部,将任务窃取到自己的工作队列中,

image.png

image.png

下面就以RecursiveTask为例,通过源码来梳理整个的工作流程。

初始化

ForkJoinPool中有四个构造参数,分别为

int parallelism :并行度(工作线程数,默认为CPU核数)

ForkJoinWorkerThreadFactory factory: 创作工作线程工厂

UncaughtExceptionHandler handler:异常处理器,当执行的任务出现异常,并从任务中被抛出时,就会被handler捕获

boolean asyncMode:任务队列出队模式 true:先进先出,false:后进先出。每个队列都绑定了独立的线程

image.png

public ForkJoinPool(int parallelism,
                    ForkJoinWorkerThreadFactory factory,
                    UncaughtExceptionHandler handler,
                    boolean asyncMode) {
    this(checkParallelism(parallelism),
         checkFactory(factory),
         handler,
         asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
         "ForkJoinPool-" + nextPoolId() + "-worker-");
    checkPermission();
}

submit

在调用submit方法时,先回判断提交的任务是否为空,如果为空,那么就会抛出异常,否则调用externalPush方法。这里的源码写的很是抽象,可阅读性不是很好,只能理解大概的意思。

public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
    if (task == null)
        throw new NullPointerException();
    externalPush(task); //将任务添加到workQueue
    return task;
}
final void externalPush(ForkJoinTask<?> task) {
    WorkQueue[] ws; WorkQueue q; int m;
    int r = ThreadLocalRandom.getProbe(); //获取随机数,用于下面计算下标
    int rs = runState;
    // 如果workqueue已经初始化,并且不为空,那么将以CAS的方式加解锁,并put元素
    if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
        (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
        U.compareAndSwapInt(q, QLOCK, 0, 1)) { //队列加锁
        ForkJoinTask<?>[] a; int am, n, s;
        if ((a = q.array) != null &&
            (am = a.length - 1) > (n = (s = q.top) - q.base)) {
            int j = ((am & s) << ASHIFT) + ABASE;
            U.putOrderedObject(a, j, task); //将任务加入队列
            U.putOrderedInt(q, QTOP, s + 1);
            U.putIntVolatile(q, QLOCK, 0); // 队列解锁
            if (n <= 1)
                signalWork(ws, q); //唤醒阻塞的线程,并扫描其他队列试图窃取任务
            return;
        }
        U.compareAndSwapInt(q, QLOCK, 1, 0);
    }
    externalSubmit(task); // 当首次添加任务时,调用该方法,从头开始创建workqueue
}

signalWork

在signalWork方法里主要做了两件事情,一个是会判断任务粒度是否满足条件,如果不满足条件,那么将会继续拆分任务。另一个是唤起空闲阻塞线程。

final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        while ((c = ctl) < 0L) {                       
            if ((sp = (int)c) == 0) {                
                if ((c & ADD_WORKER) != 0L)  //判断当前任务是否达到拆分任务的粒度,如果没有则继续创建线程拆分          
                    
                    tryAddWorker(c); //添加 Worker,创建线程
                break;
            }
            if (ws == null)      //未开始或者已停止,直接跳出                       
                break;
            if (ws.length <= (i = sp & SMASK))         
                break;
            if ((v = ws[i]) == null)                  
                break;
              //程序执行到这里,说明有空闲线程,计算下一个scanState,增加了版本号,并且调整为 active 状态
            int vs = (sp + SS_SEQ) & ~INACTIVE;       
            int d = sp - v.scanState;                  
            //计算下一个ctl的值,活动线程数 AC + 1,通过stackPred取得前一个WorkQueue的索引,重新设置回sp,行程最终的ctl值
              long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
              //更新 ctl 的值
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;                     
                  //如果有线程阻塞,则调用unpark唤醒即可 
                  if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
              //没有任务,直接跳出
            if (q != null && q.base == q.top)          
                break;
        }

/**
* 通过线程工厂创建线程,并启动
*/
private boolean createWorker() {
    ForkJoinWorkerThreadFactory fac = factory;
    Throwable ex = null;
    ForkJoinWorkerThread wt = null;
    try {
        if (fac != null && (wt = fac.newThread(this)) != null) { // 线程工厂不为空,并且创建线程成功
            wt.start(); // 启动线程
            return true;
        }
    } catch (Throwable rex) {
        ex = rex;
    }
    deregisterWorker(wt, ex); //如果创建线程失败,就要逆向注销线程
    return false;
}
protected ForkJoinWorkerThread(ForkJoinPool pool) {
  super("aForkJoinWorkerThread");
  this.pool = pool;
  this.workQueue = pool.registerWorker(this); // 线程与队列绑定
  }

fork

在fork方面里面其实逻辑相对简单一些,他会判断当前的线程是否为ForkJoinPool的内部线程,如果是内部线程,那么将调用push方法,将任务存入当前线程绑定的workqueue中。否则重新调用externalPush方法。

public final ForkJoinTask<V> fork() {
  Thread t;
  if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) //判断当前线程是否为内部线程
    ((ForkJoinWorkerThread)t).workQueue.push(this); //入栈
  else
    ForkJoinPool.common.externalPush(this); //重新调用externalPush,在上面有源码
  return this;
}

join

join方法中主要判断当前任务是否完成,其核心方法是doJoin()。但是写这个代码的人真的想吐槽一下,代码可读性实在太不友好了。

public final V join() {
    int s;
    if ((s = doJoin() & DONE_MASK) != NORMAL)
        reportException(s);
    return getRawResult(); // 返回执行结果
}


private int doJoin() {
    int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
    
     /**
         * (s = status) < 0 判断任务是否已经完成,完成直接返回
         * 任务未完成:
         * 1)判断线程是否为内部线程,如果是内部线程,则调用tryUnpush任务出队然后消费任务doExec并返回结果。
         *    如果这个任务偷走了,执行awaitJoin进行自旋,如果任务状态是完成就退出,否则继续尝试出队,
         *    直到任务完成或超时为止;
         * 2)如果线程不是内部线程,执行externalAwaitDone进行出队消费
         */
    return (s = status) < 0 ? s :
        ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (w = (wt = (ForkJoinWorkerThread)t).workQueue).
        tryUnpush(this) && (s = doExec()) < 0 ? s :
        wt.pool.awaitJoin(w, this, 0L) :
        externalAwaitDone();
}

final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
    int s = 0;
    if (task != null && w != null) {
        ForkJoinTask<?> prevJoin = w.currentJoin;
        U.putOrderedObject(w, QCURRENTJOIN, task);
        CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
            (CountedCompleter<?>)task : null;
        for (;;) {
            if ((s = task.status) < 0) //判断任务是否完成,完成则终止自旋
                break;
            if (cc != null) 
                //CountedCompleter类型的任务调用helpComplete()方法
                helpComplete(w, cc, 0);
            else if (w.base == w.top || w.tryRemoveAndExec(task))
               //任务被别的线程偷走了,帮助偷取者执行该任务
                helpStealer(w, task);
            if ((s = task.status) < 0)
                break;  //判断任务是否完成,完成则终止自旋
            long ms, ns;
            if (deadline == 0L) 
                ms = 0L;
            else if ((ns = deadline - System.nanoTime()) <= 0L) // 任务超时,终止自旋
                break;
            else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                ms = 1L;
            if (tryCompensate(w)) {
                task.internalWait(ms);
                U.getAndAddLong(this, CTL, AC_UNIT);
            }
        }
        U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
    }
    return s;
}

任务窃取Scan

private ForkJoinTask<?> scan(WorkQueue w, int r) {
  WorkQueue[] ws; int m;
  //再次验证workQueue[]数组的初始化情况
  if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
    //获取当前扫描状态
    int ss = w.scanState;                     // initially non-negative
    
    //又一个死循环,注意到出口位置就好
    /随机一个起始位置,并赋值给k
    for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
      WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
      int b, n; long c;
      //如果k槽位不为空
      if ((q = ws[k]) != null) {
        //base-top小于零,并且任务q不为空
        if ((n = (b = q.base) - q.top) < 0 &&
            (a = q.array) != null) {      // non-empty
          //获取base的偏移量,赋值给i
          long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
          if ((t = ((ForkJoinTask<?>)
                    U.getObjectVolatile(a, i))) != null &&
              q.base == b) {
            //是active状态
            if (ss >= 0) {
              //更新WorkQueue中数组i索引位置为空,并且更新base的值
              if (U.compareAndSwapObject(a, i, t, null)) {
                q.base = b + 1;
                //n<-1,说明当前队列还有剩余任务,继续唤醒可能存在的其他线程
                if (n < -1)       // signal others
                  signalWork(ws, q);
                //直接返回任务
                return t;
              }
            }
            else if (oldSum == 0 &&   // try to activate
                     w.scanState < 0)
              tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
          }
          
          //如果获取任务失败,则准备换位置扫描
          if (ss < 0)                   // refresh
            ss = w.scanState;
          r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
          origin = k = r & m;           // move and rescan
          oldSum = checkSum = 0;
          continue;
        }
        checkSum += b;
      }
      
      //k一直在变,扫描到最后,如果等于origin,说明已经扫描了一圈还没扫描到任务
      if ((k = (k + 1) & m) == origin) {    // continue until stable
        if ((ss >= 0 || (ss == (ss = w.scanState))) &&
            oldSum == (oldSum = checkSum)) {
          if (ss < 0 || w.qlock < 0)    // already inactive
            break;
          //准备inactive当前工作队列
          int ns = ss | INACTIVE;       // try to inactivate
          //活动线程数AC减1
          long nc = ((SP_MASK & ns) |
                     (UC_MASK & ((c = ctl) - AC_UNIT)));
          w.stackPred = (int)c;         // hold prev stack top
          U.putInt(w, QSCANSTATE, ns);
          if (U.compareAndSwapLong(this, CTL, c, nc))
            ss = ns;
          else
            w.scanState = ss;         // back out
        }
        checkSum = 0;
      }
    }
  }
  return null;
}

总结

上面主要介绍了在日常开发中不太常用的两个功能点ForkJoin和Future,其作用是充分利用CPU多核的特点,来完成特殊的需求。特别是ForkJoin,虽然代码比较难以理解,但是其中的分治思想值得借鉴。但不是任何时候都可以使用ForkJoin的,当数据量不是特别大的时候,我们没有必要使用ForkJoin。在多线程工作时,经常会上下文切换,这个是比较耗时的,所以数据量不大的时候使用单线程比使用多线程快。