分治思想之ForkJoin

1,290 阅读2分钟

前言

  当我们面对需要同时执行多个任务的情况时,往往需要耗费大量的时间和精力来编写复杂的并发代码。但是有一种技术可以帮助我们轻松地处理这种情况。通过forkJoin,我们可以简单地实现多个任务的并行执行,从而提高应用程序的性能和响应能力。在本文中,我们将深入探讨forkJoin的工作原理和使用方法。

分治

  fork-join模式是基于分治思想的并行计算模式之一。该模式将一个大的任务分割成多个小的子任务,然后并行执行这些子任务,最后将它们的结果合并起来得到最终的结果。在这个过程中,每个子任务的执行可以进一步分解为更小的子任务,直到达到某个阈值,这时候任务将被串行执行。这种递归的分治思想使得fork-join模式可以有效地利用计算机的多核处理能力,从而提高程序的性能和效率。

分治思想算法

归并排序

归并排序是一种基于分治思想的排序算法。其核心思想是将一个数组分成两个子数组,分别对其进行排序后再将其合并起来。具体实现过程如下:

  1. 分解:将一个数组分成两个子数组,分别对其进行排序。可以使用递归来实现这一步骤。
  2. 合并:将排序后的两个子数组合并成一个有序的数组。
  3. 递归:对两个子数组递归进行分解和排序,直到每个子数组的长度为1。

时间复杂度为O(nlogn)。

image.png

public class Merge {

    public static void main(String[] args) {
        int[] arr = { 5, 2, 8, 4, 7, 1, 3, 6 };
        mergeSort(arr, 0, arr.length - 1);
        System.out.println(Arrays.toString(arr));
    }

    /**
     * 拆分
     * @param arr 数组
     * @param left 数组第一个元素
     * @param right 数组最后一个元素
     */
    public static void mergeSort(int[] arr,int left,int right){
        if (left < right) {
            int mid = (left + right) / 2;
            mergeSort(arr, left, mid);
            mergeSort(arr, mid + 1, right);
            merge(arr, left, mid, right);
        }
    }


    /**
     * 合并
     * @param arr 数组
     * @param left 数组第一个元素
     * @param mid 数组分割的元素
     * @param right 数组最后一个元素
     */
    public static void merge(int[] arr,int left,int mid,int right){
        //创建临时数组,用于存放合并后的数组
        int[] temp = new int[right - left + 1];
        //左面数组的起始指针
        int i = left;
        //右面数组的起始指针
        int j = mid + 1;
        //临时数组的下标
        int k = 0;
        //数组左面和数组右面都还有值就去遍历比较赋值
        while (i <= mid && j <= right) {
            //数组左面的值小于或者等于数组右面的值就把左面的值赋值到临时数组中
            //并且把左面的指针和临时数组的指针+1
            //否则相反
            if (arr[i] <= arr[j]) {
                temp[k] = arr[i];
                k++;
                i++;
            } else {
                temp[k] = arr[j];
                k++;
                j++;
            }
        }
        //把剩余数组塞进去
        while (i <= mid) {
            temp[k] = arr[i];
            k++;
            i++;
        }
        while (j <= right) {
            temp[k] = arr[j];
            k++;
            j++;
        }
        //讲临时数组中的元素拷贝会回原数组中
        for (int p = 0; p < temp.length; p++) {
            arr[left + p] = temp[p];
        }
    }
}

快速排序

快速排序(Quick Sort)是一种基于分治思想的排序算法,它采用了递归的方式将一个大的数组分解成多个子数组,分别进行排序后再将它们合并起来。其基本思想是选取一个基准元素,将数组中小于该元素的值放在左边,大于该元素的值放在右边,然后递归地对左右两个子数组进行排序。具体的步骤如下:

  1. 选取一个基准元素(通常是数组中的第一个元素)。
  2. 将数组中小于该元素的值放在左边,大于该元素的值放在右边。
  3. 对左右两个子数组分别递归进行快速排序。
  4. 合并左右两个已排序的子数组。

快速排序的时间复杂度为O(nlogn),它是一种原地排序算法,不需要额外的存储空间,因此空间复杂度为O(1)。虽然快速排序的最坏时间复杂度为O(n^2),但是在实际应用中,快速排序的平均时间复杂度和最好时间复杂度均为O(nlogn),因此是一种非常高效的排序算法

快速排序.png

public class QuickSort {

    public static void main(String[] args) {
        int[] arr = { 5, 2, 8, 4, 7, 1, 3, 6 };
        quickSort(arr, 0, arr.length - 1);
        System.out.println(Arrays.toString(arr));
    }


    public static void quickSort(int[] arr,int left,int right){
        if(left<right){
            int pivotIndex  = partition(arr, left, right);
            quickSort(arr,left,pivotIndex-1);
            quickSort(arr,pivotIndex+1,right);
        }
    }

    public static int partition(int[] arr,int left,int right){
        //取第一个元素为基准元素
        int pivot = arr[left];
        int i = left+1;
        int j = right;
        while (i<=j){
            //当前指针位置数小于基准元素就继续移动指针直到遇到大的
            while (i<=j && arr[i] < pivot){
                i++;
            }
            //当前指针位置数大于基准元素就继续移动指针直到遇到小的
            while (i<=j && arr[j] > pivot){
                j--;
            }
            if(i<j){
                //交换元素位置
                swap(arr,i,j);
                i++;
                j--;
            }
        }
        //跳出循环说明i和j相遇,j的值一定是大于基准元素,要和基准元素进行交换
        swap(arr,left,j);
        //返回基准元素位置
        return j;
    }

    public static void swap(int[] array, int i, int j) {
        int temp = array[i];
        array[i] = array[j];
        array[j] = temp;
    }
}

Fork/Join

  Fork/Join框架的主要组成部分是ForkJoinPool、ForkJoinTask。ForkJoinPool是一个线程池,它用于管理ForkJoin任务的执行。ForkJoinTask是一个抽象类,用于表示可以被分割成更小部分的任务。

ForkJoinPool

ForkJoinPool是Fork/Join框架中的线程池类,它用于管理Fork/Join任务的线程。ForkJoinPool类包括一些重要的方法,例如submit()、invoke()、shutdown()、awaitTermination()等,用于提交任务、执行任务、关闭线程池和等待任务的执行结果。ForkJoinPool类中还包括一些参数,例如线程池的大小、工作线程的优先级、任务队列的容量等,可以根据具体的应用场景进行设置。

构造器

ForkJoinPool中有四个核心参数,用于控制线程池的并行数、工作线程的创建、异常处理和模式指定等。

  • int parallelism:指定并行级别(parallelism level)。ForkJoinPool将根据这个设定,决定工作线程的数量。如果未设置的话,将使用Runtime.getRuntime().availableProcessors()来设置并行级别;
  • ForkJoinWorkerThreadFactory factory:ForkJoinPool在创建线程时,会通过factory来创建。注意,这里需要实现的是ForkJoinWorkerThreadFactory,而不是ThreadFactory。如果你不指定factory,那么将由默认的 DefaultForkJoinWorkerThreadFactory负责线程的创建工作;
  • UncaughtExceptionHandler handler:指定异常处理器,当任务在运行中出错时,将由设定的处理器处理;
  • boolean asyncMode:设置队列的工作模式。当asyncMode为true时,将使用先进先出队列,而为false时则使用后进先出的模式。

工作窃取法

在Fork-Join模式中,任务被分配给一个线程池中的工作线程来执行。当一个工作线程执行完自己分配的任务后,它可以从其他工作线程的任务队列中偷取(Steal)任务来执行,这就是所谓的工作窃取(Work Stealing)。

使用

public class SumTask extends RecursiveTask<Integer> {
    private static final int THRESHOLD = 1000;
    private int[] array;
    private int start;
    private int end;

    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        if (end - start <= THRESHOLD) {
            int sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        } else {
            int mid = (start + end) / 2;
            SumTask leftTask = new SumTask(array, start, mid);
            SumTask rightTask = new SumTask(array, mid, end);
            leftTask.fork();
            rightTask.fork();
            int leftResult = leftTask.join();
            int rightResult = rightTask.join();
            return leftResult + rightResult;
        }
    }

    public static void main(String[] args) {
        int[] array = new int[10000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        SumTask task = new SumTask(array, 0, array.length);
        int result = forkJoinPool.invoke(task);
        System.out.println(result);
    }
}