ForkJoin(分支合并)
什么是ForkJoin
ForkJoin
是一种基于工作窃取算法的并行编程模型,用于高效地处理递归式任务。在 Java 中,ForkJoinPool
是一个实现了这种并行编程模型的线程池。
在 ForkJoin
中,一个大的任务被递归地分解成多个小的子任务,并行地执行这些子任务,最终将它们的结果合并起来得到整个任务的结果。这种分解和合并的过程可以形成一棵树状结构,每个节点都代表一个子任务,而叶子节点则代表最小的可执行任务。
为了避免线程之间的竞争和同步开销,ForkJoin
使用了工作窃取算法。这种算法让每个线程都维护一个自己的任务队列,当一个线程执行完自己的任务后,它会从其他线程的队列中窃取一个任务来执行。这种方式可以保证每个线程都能保持忙碌,最大限度地利用 CPU 和内存资源。
ForkJoinPool
是一个实现了 ForkJoin
并行编程模型的线程池,它提供了一种高效、灵活和易于使用的方式来并行执行递归式任务。在 Java 7 中引入了 ForkJoin
和 ForkJoinPool
,并在 Java 8 中与 Stream
和 CompletableFuture
等新特性结合使用,使得 Java 开发人员更容易编写高效的并行代码。
怎么使用
在 ForkJoin
并行编程模型中,大的任务被递归分解成多个小的子任务,这些子任务被并行地执行,并最终合并它们的结果来得到整个任务的结果。为了实现这个过程,可以使用 RecursiveTask
或 RecursiveAction
类来表示可分解的任务和它们的子任务。
RecursiveTask
和 RecursiveAction
类都是抽象类,需要通过继承和实现其抽象方法来创建自定义的可分解任务。其中,RecursiveTask
类用于返回结果的任务,而 RecursiveAction
类用于不返回结果的任务。
下面是一个简单的示例,展示了如何使用 RecursiveTask
类将一个大的数组分解成多个小的子数组并计算它们的总和:
public class Main {
public static void main(String[] args) {
SumTask sumTask = new SumTask(new int[10000], 0, 10000);
System.out.println(sumTask.compute());
}
}
class SumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 1000;
private final int[] array;
private final int start;
private final int end;
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
// 这是RecursiveTask类里的抽象compute方法,需要重写
@Override
protected Long compute() {
if (end - start <= THRESHOLD) {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
} else {
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(array, start, mid);
SumTask rightTask = new SumTask(array, mid, end);
// 将左半边的数组交给新的线程去执行任务
leftTask.fork();
// 右边数组进行compute方法
long rightResult = rightTask.compute();
// 等待左边数组的运行结果
long leftResult = leftTask.join();
return leftResult + rightResult;
}
}
}
代码分析
这个示例中,SumTask
类继承了 RecursiveTask
类,并实现了其中的 compute()
方法。在 compute()
方法中,如果任务的大小小于或等于阈值 THRESHOLD
,则直接计算任务的总和并返回结果。否则,将任务分解成两个子任务,并使用 fork()
和 join()
方法将它们并行执行,并最终将它们的结果合并起来得到总和。
在这个例子中,任务被递归分解成多个小的子任务,并且通过 fork()
和 join()
方法并行地执行它们。这种方式可以最大限度地利用 CPU 和内存资源,提高计算效率和吞吐量。
我的SumTask类里并没有fork()
和join()
方法,那它是如何来的呢,又有啥作用呢?
点击RecursiveTask
进入源码, 你会发现它继承了ForkJoinTask
, 进到ForkJoinTask
源码。
可以发现:
fork方法
fork()
方法是 ForkJoinTask
类中的一个方法,它的作用是将当前的任务异步地分配给一个工作线程进行执行。具体来说,当当前任务需要进一步分解成子任务时,可以通过调用 fork()
方法将子任务提交到工作线程池中,并立即返回,继续执行其他的任务或者代码。这样可以避免当前线程被阻塞,提高程序的并行度和执行效率。
在 ForkJoinTask
类的 fork()
方法中,首先使用 Thread.currentThread()
方法获取当前线程,如果当前线程是一个 ForkJoinWorkerThread
类型的线程,说明当前线程已经加入到了线程池中,可以直接将当前任务提交到该线程的工作队列中。否则,如果当前线程不是一个 ForkJoinWorkerThread
类型的线程,说明当前线程不是从线程池中获取的,需要使用 ForkJoinPool.common.externalPush(this)
方法将任务提交到线程池的共享队列中。
fork()
方法的返回值是当前任务本身,这样可以将任务的提交和返回合并在一起,更加方便任务的编写和调用。同时,fork()
方法是一个 final
方法,不能被子类重写,确保了任务的提交和执行的一致性和可靠性。
ForkJoinTask的join()方法
用于返回任务的结果,join()
方法会等待当前任务的所有子任务完成并返回结果,然后将这些结果合并成一个整体结果,并返回给调用者。如果当前任务没有子任务,或者子任务已经全部执行完成并返回结果,则 join()
方法会立即返回当前任务的结果,不会被阻塞。
为什么不需要自己新建线程池
ForkJoinWorkerThread
维护了一个线程池ForkJoinPool
, 它会自己按照需求来创建销毁线程。
在使用 Fork/Join 框架时,用户不需要关心线程池的创建和管理细节,只需要编写好任务的递归分解和计算逻辑,然后将任务提交到框架中即可。框架会自动将任务分配给适当的线程池和工作线程来执行,并保证任务的执行效率和正确性。