Fork-Join分治编程

72,928 阅读1分钟

一、Fork-Join的运行流程

如下所示:

在JDK中并行执行框架Fork-Join使用了“工作窃取”算法,它是指某个线程从其他队列里窃取任务来执行。

二、实现Fork-Join大体分为两步:

第一步、分割任务

把大任务分割成子任务,如果子任务不够小,则继续往下分,直到分割出的子任务足够小。

通常会继承以下两个类,实现其compute方法。

  • RecursiveAction执行的任务是无返回值的,仅执行一次任务。
  • RecursiveTask执行的任务具有返回值的功能。

第二步、任务执行并返回结果

任务的执行需要通过 ForkJoinPool 来执行。

三、示例

用Fork-Join求一个亿级 Integer 数据量array的和

import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 * 亿级 Integer 数据量的一个 array 求和。
 * Created on 2020-5-24
 */
public class ArrayCountTask extends RecursiveTask<Long> {

    private final int start, end;
    private static final int grain = 100000;
    private final int[] data;

    public ArrayCountTask(int start, int end, int[] data) {
        this.start = start;
        this.end = end;
        this.data = data;
    }


    @Override
    protected Long compute() {
        long count = 0;
        if (end - start <= grain) {
            for (int i = start; i < end; i++) {
                count += data[i];
            }
        } else {
            int middle = (start + end) >>> 1;
            RecursiveTask left = new ArrayCountTask(start, middle, data);
            RecursiveTask right = new ArrayCountTask(middle, end, data);
            invokeAll(left, right);
            long leftJoin = (long) left.join();
            long rightJoin = (long) right.join();
            count = leftJoin + rightJoin;

        }
        return count;
    }

    public static void main(String[] args) {
        int total = 500000000;
        int[] data = new int[total];
        for (int i = 0; i < total; i++) {
            data[i] = new Random().nextInt(5);
        }

        long startTime = System.currentTimeMillis();
        long sum = 0;
        for (int i = 0; i < total; i++) {
            sum += data[i];
        }
        System.out.println("用普通方式计算的结果:sum = " + sum);
        System.out.println("耗时:" + (System.currentTimeMillis() - startTime));

        ArrayCountTask task = new ArrayCountTask(0, data.length, data);
        ForkJoinPool pool = new ForkJoinPool();

        startTime = System.currentTimeMillis();
        sum = pool.invoke(task);
        System.out.println("用ForkJoin的方式计算的结果:sum = " + sum);
        System.out.println("耗时:" + (System.currentTimeMillis() - startTime));
    }
}

计算结果:

计算所用八核CPU,小数据量性能提升不是特别明显。