【多线程】Java多线程基础(16)- 使用ForkJoin

49 阅读5分钟

ForkJoin(分支合并)

什么是ForkJoin

ForkJoin 是一种基于工作窃取算法的并行编程模型,用于高效地处理递归式任务。在 Java 中,ForkJoinPool 是一个实现了这种并行编程模型的线程池。

在 ForkJoin 中,一个大的任务被递归地分解成多个小的子任务,并行地执行这些子任务,最终将它们的结果合并起来得到整个任务的结果。这种分解和合并的过程可以形成一棵树状结构,每个节点都代表一个子任务,而叶子节点则代表最小的可执行任务。

为了避免线程之间的竞争和同步开销,ForkJoin 使用了工作窃取算法。这种算法让每个线程都维护一个自己的任务队列,当一个线程执行完自己的任务后,它会从其他线程的队列中窃取一个任务来执行。这种方式可以保证每个线程都能保持忙碌,最大限度地利用 CPU 和内存资源。

ForkJoinPool 是一个实现了 ForkJoin 并行编程模型的线程池,它提供了一种高效、灵活和易于使用的方式来并行执行递归式任务。在 Java 7 中引入了 ForkJoin 和 ForkJoinPool,并在 Java 8 中与 Stream 和 CompletableFuture 等新特性结合使用,使得 Java 开发人员更容易编写高效的并行代码。

怎么使用

在 ForkJoin 并行编程模型中,大的任务被递归分解成多个小的子任务,这些子任务被并行地执行,并最终合并它们的结果来得到整个任务的结果。为了实现这个过程,可以使用 RecursiveTask 或 RecursiveAction 类来表示可分解的任务和它们的子任务。

RecursiveTask 和 RecursiveAction 类都是抽象类,需要通过继承和实现其抽象方法来创建自定义的可分解任务。其中,RecursiveTask 类用于返回结果的任务,而 RecursiveAction 类用于不返回结果的任务。

下面是一个简单的示例,展示了如何使用 RecursiveTask 类将一个大的数组分解成多个小的子数组并计算它们的总和:

public class Main {
    public static void main(String[] args) {
        SumTask sumTask = new SumTask(new int[10000], 0, 10000);
        System.out.println(sumTask.compute());
    }
}
 class SumTask extends RecursiveTask<Long> {

    private static final int THRESHOLD = 1000;

    private final int[] array;
    private final int start;
    private final int end;

    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    // 这是RecursiveTask类里的抽象compute方法,需要重写
    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        } else {
            int mid = (start + end) / 2;
            SumTask leftTask = new SumTask(array, start, mid);
            SumTask rightTask = new SumTask(array, mid, end);
            // 将左半边的数组交给新的线程去执行任务
            leftTask.fork();
            // 右边数组进行compute方法
            long rightResult = rightTask.compute();
            // 等待左边数组的运行结果
            long leftResult = leftTask.join();
            return leftResult + rightResult;
        }
    }
}

代码分析

这个示例中,SumTask 类继承了 RecursiveTask 类,并实现了其中的 compute() 方法。在 compute() 方法中,如果任务的大小小于或等于阈值 THRESHOLD,则直接计算任务的总和并返回结果。否则,将任务分解成两个子任务,并使用 fork() 和 join() 方法将它们并行执行,并最终将它们的结果合并起来得到总和。

在这个例子中,任务被递归分解成多个小的子任务,并且通过 fork() 和 join() 方法并行地执行它们。这种方式可以最大限度地利用 CPU 和内存资源,提高计算效率和吞吐量。

我的SumTask类里并没有fork()join()方法,那它是如何来的呢,又有啥作用呢?

点击RecursiveTask进入源码, 你会发现它继承了ForkJoinTask, 进到ForkJoinTask源码。

可以发现:

fork方法

fork() 方法是 ForkJoinTask 类中的一个方法,它的作用是将当前的任务异步地分配给一个工作线程进行执行。具体来说,当当前任务需要进一步分解成子任务时,可以通过调用 fork() 方法将子任务提交到工作线程池中,并立即返回,继续执行其他的任务或者代码。这样可以避免当前线程被阻塞,提高程序的并行度和执行效率。

在 ForkJoinTask 类的 fork() 方法中,首先使用 Thread.currentThread() 方法获取当前线程,如果当前线程是一个 ForkJoinWorkerThread 类型的线程,说明当前线程已经加入到了线程池中,可以直接将当前任务提交到该线程的工作队列中。否则,如果当前线程不是一个 ForkJoinWorkerThread 类型的线程,说明当前线程不是从线程池中获取的,需要使用 ForkJoinPool.common.externalPush(this) 方法将任务提交到线程池的共享队列中。

fork() 方法的返回值是当前任务本身,这样可以将任务的提交和返回合并在一起,更加方便任务的编写和调用。同时,fork() 方法是一个 final 方法,不能被子类重写,确保了任务的提交和执行的一致性和可靠性。

ForkJoinTask的join()方法

用于返回任务的结果,join() 方法会等待当前任务的所有子任务完成并返回结果,然后将这些结果合并成一个整体结果,并返回给调用者。如果当前任务没有子任务,或者子任务已经全部执行完成并返回结果,则 join() 方法会立即返回当前任务的结果,不会被阻塞。

为什么不需要自己新建线程池

ForkJoinWorkerThread维护了一个线程池ForkJoinPool, 它会自己按照需求来创建销毁线程。

在使用 Fork/Join 框架时,用户不需要关心线程池的创建和管理细节,只需要编写好任务的递归分解和计算逻辑,然后将任务提交到框架中即可。框架会自动将任务分配给适当的线程池和工作线程来执行,并保证任务的执行效率和正确性。