《Java7并发编程实战手册》学习笔记(六)——Fork/Join框架

753 阅读19分钟

此篇博客为个人学习笔记,如有错误欢迎大家指正

本次内容

  1. 创建Fork/Join线程池
  2. 合并任务的结果
  3. 异步运行任务
  4. 在任务中抛出异常
  5. 取消任务

Java为我们提供了ExecutorService接口的另一种实现——Fork/Join框架(分解/合并框架),这个框架能帮助我们更简单的用分治技术解决问题。使用Fork/Join框架,在执行一个任务时,我们首先判断这个任务的规模是否大于我们制定的标准,如果大于就将这个任务分解(fork)为规模更小的任务去执行;最后再将执行完成小任务层层合并(join)为大任务并返回,原理图如下:

Fork/Join框架与我们之前使用的执行器框架的主要区别在于前者实现了工作窃取算法。当我们使用join()方法使一个主任务等待它所创建的子任务完成时,执行任务的线程(工作者线程)并不会因为等待其他任务的完成而进入休眠状态,而是随机的去其他线程所维护的双端队列末尾取出一个任务来执行,这就极大的提升了工作效率。当然,为了达到上述目标,在使用Fork/Join框架时有以下限制:

  • 任务只能使用fork()join()等一些专门为Fork/Join框架准备的方法进行同步。如果使用了其他同步机制,工作者线程会真正的进入阻塞状态并且不会窃取其他线程的任务来执行
  • 任务不能执行I/O操作
  • 任务不可以抛出非运行时异常

Fork/Join框架的核心是由以下两个类组成的:

  • ForkJoinPool:这个类也实现了ExecutorExecutorService接口,和我们之前使用过的ThreadPoolExecutor类有些类似,主要区别在于这个类实现了工作窃取算法。获取ForkJoinPool对象的方法主要有以下几种,我们可以根据不同的需求进行选择:

    1. 首先是ForkJoinPool类的构造方法:

      • ForkJoinPool():无参构造方法,调用此方法获得的ForkJoinPool对象将执行默认的配置。其并行级别为当前JVM可以调用的CPU内核数量
      • ForkJoinPool(int parallelism):通过这个构造方法可以指定线程池的并行级别,但是我们传入的参数应该是大于0且小于等于JVM可以调用的CPU内核数量的
      • ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode) 此方法参数较多,下面会分别记录:
        1. parallelism:并行级别。Fork/Join框架将根据这个参数来设定框架内并行执行的线程数。注意,并不是框架中最大的线程数量
        2. factory:线程工厂。我们可以编写自己的Fork/Join县城工厂类,和之构造的线程工厂不同。构造Fork/Join线程工厂类需要我们实现ForkJoinWorkerThreadFactory接口而不是ThreadFactory接口
        3. handler:异常捕获处理器。当执行中的任务向上抛出异常时,就会被处理器捕获
        4. asyncMode:工作模式。Fork/Join框架中的每一个工作线程都维护着一个双端队列用于装载任务,参数为true则表示队列中的任务先进先出(FIFO),为false则表示后进先出(LIFO)
    2. ForkJoinPool类的静态方法commonPool()同样可以获得ForkJoinPool对象。值得注意的是,调用此方法获得的是Java预定义的线程池,这可以减少资源的消耗因为我们不再需要每提交一个任务就创建一个新的线程池了,也就是说每次我们调用此方法所获得的对象引用实际上都指向同一个线程池,可以发现执行下面的代码打印出的值将为true

          ForkJoinPool forkJoinPool1 = ForkJoinPool.commonPool();
          ForkJoinPool forkJoinPool2 = ForkJoinPool.commonPool();
          System.out.println(forkJoinPool1 == forkJoinPool2);
      

      另外,在调用经此方法获得的ForkJoinPool对象的shutdown()方法时,线程池并不会关闭

    3. 使用Executors类的静态方法newWorkStealingPool()或者此方法的另一种实现newWorkStealingPool(int parallelism)

  • ForkJoinTask:此类实现了Future接口,是在ForkJoinPool中执行的任务的基类。为了使用Fork/Join框架执行任务,通常情况下我们需要实现以下两个ForkJoinTask子类的其中一个

    • RecursiveAciton:用于任务没有返回结果的场景
    • RecursiveTask:用于任务有返回结果的场景

    在继承上面两个类后,我们最好在自己的类中加上这样一个属性:
    private static final long serialVersionUID = 1L;
    这是因为RecursiveActionRecursiveTask类均继承了ForkJoinTask类,而ForkJoinTask类又实现了Serializable接口。如果我们不显示的声明这个属性,那么Java会根据当前类的属性、方法给出一个默认值。当我们修改了类的属性或方法后,这个值会发生变化。这样一来,我们在将修改之前进行过序列化的类进行反序列化时就会出现错误。所以我们最好显示的声明这一属性。

1.创建Fork/Join线程池

使用Fork/Join框架,我们最好参考JavaAPI手册为我们推荐的代码结构

if (problem size > default size) {
    tasks = divide(task);
    execute(tasks);
} else {
    resolve problem using another algorthm;
}

下面是在此小节中需要了解的方法:

  • ForkJoinPool类:
    1. execute(ForkJoinTask<?> task):无返回值。调用此方法向线程池提交一个任务,注意这个方法是异步的,调用后线程不会等待而是直接向下执行。execute(Runnable task)是另一种实现,提交一个Runnable类型的任务给线程池,在这种情况下线程池不会使用工作窃取算法
    2. invoke(ForkJoinTask<T> task):此方法最好和execute(ForkJoinTask<?> task)方法对比来看。区别在与这个方法是同步的,调用后会直到任务执行结束后才返回。返回值即为任务返回的结果
    3. 因为ForkJoinPool类实现了ExecutorService接口,所以也实现了invokeAll()invokeAny()方法。这些方法之前都已经使用过,参数为Callable类型的任务列表。但是当我们向ForkJoinPool发送Runnable或Callable类型的任务时,线程池并不会使用工作窃取算法,因此我们不推荐这样做
  • ForkJoinTask类:
    1. adapt():传入一个Runnable或Callable对象,返回一个ForkJoinTask对象
    2. invokeAll():传入ForkJoinTask对象列表或数个ForkJoinTask对象。这个方法是同步的,当主任务在等待子任务时,执行主任务的工作线程会开始执行另一个等待执行的任务。值得注意的是,因传入参数不同这个方法的返回值也有所区别。直接传入ForkJoinTask对象的话此方法没有返回值;传入ForkJoinTask对象列表的话返回值也为传入的ForkJoinTask对象列表,并且经过调试我们可以发现传入和返回的两个列表对象的引用实际是指向同一个对象

范例实现

在这个范例中,我们将对所有商品使用分治技术进行涨价操作。由于任务不需要有返回值,我们的任务类继承了RecursiveAciton
商品类:

package day06.code_1;

public class Product {

    //商品名称
    private String name;

    //商品价格
    private double price;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public double getPrice() {
        return price;
    }

    public void setPrice(double price) {
        this.price = price;
    }
}

商品列表生成类:

package day06.code_1;

import java.util.ArrayList;
import java.util.List;

public class ProductListGenerator {

    //根据传入的大小创建一个产品集合
    public List<Product> generate(int size) {
        //创建一个集合
        ArrayList<Product> products = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            //创建产品
            Product product = new Product();
            //设置名字
            product.setName("Product " + i);
            //统一设置初始价格为10,方便检查程序的正确性
            product.setPrice(10);
            //装入集合
            products.add(product);
        }
        //返回集合
        return products;
    }

}

任务类:

package day06.code_1;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.RecursiveAction;

public class Task extends RecursiveAction {

    //必备参数
    private static final long serialVersionUID = 1L;

    //产品集合
    private List<Product> products;

    //起始位置
    private int first;

    //终止位置
    private int last;

    //价格增百分比
    private double increment;

    public Task(List<Product> products, int first, int last, double increment) {
        this.products = products;
        this.first = first;
        this.last = last;
        this.increment = increment;
    }

    @Override
    protected void compute() {
        //如果任务数量小于10
        if (last - first < 10) {
            //执行涨价操作
            updatePrices();
        } else {
            //如果任务数量大于10则将任务均分
            int middle = (first + last) / 2;
            //打印分割任务提示语
            System.out.printf("Task: Pending tasks:%s\n",
                    getQueuedTaskCount());
            //根据新分配的范围创建两个任务
            Task t1 = new Task(products, first, middle + 1, increment);
            Task t2 = new Task(products, middle + 1, last, increment);
            //执行
            invokeAll(t1, t2);
        }
    }

    private void updatePrices() {
        //遍历集合为每一个商品做涨价操作
        for (int i = first; i < last; i++) {
            Product product = products.get(i);
            product.setPrice(product.getPrice() * (1 + increment));
        }
    }
}

main方法:

package day06.code_1;

import java.util.List;
import java.util.concurrent.ForkJoinPool;

public class Main {

    public static void main(String[] args) {
        //创建产品生成对象
        ProductListGenerator generator = new ProductListGenerator();
        //通过产品生成器得到大小为10000的产品集合
        List<Product> products = generator.generate(10000);
        //创建一个任务
        Task task = new Task(products, 0, 10000, 0.20);
        //创建线程池
        ForkJoinPool pool = new ForkJoinPool();
        //调用线程池的方法执行任务
        pool.execute(task);
        do {
            //打印线程池中当前正在执行任务的线程数量
            System.out.printf("Main: Thread Count: %d\n",
                    pool.getActiveThreadCount());
            //打印线程池中窃取的工作数量
            System.out.printf("Main: Thread Steal: %d\n",
                    pool.getStealCount());
            //打印线程池的并行级别
            System.out.printf("Main: Parallelism: %d\n",
                    pool.getParallelism());
            //休眠5秒
            try {
                Thread.sleep(5);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            //等待任务结束
        } while (!task.isDone());
        //关闭线程池
        pool.shutdown();
        //判断任务是否抛出了异常
        if (task.isCompletedNormally()) {
            //打印任务无异常完成的提示信息
            System.out.printf("Main: The process has completed normally\n");
        }
        //检查商品是否已正确涨价
        for (int i = 0; i < products.size(); i++) {
            Product product = products.get(i);
            if (product.getPrice() != 12) {
                System.out.printf("Product %s: %f\n",
                        product.getName(), product.getPrice());
            }
        }
        //打印程序结束提示语
        System.out.println("Main: End of the program\n");
    }

}

2.合并任务的结果

使用Fork/Join框架执行带有返回值的任务时必须继承RecursiveTask类并使用JavaAPI文档推荐的结构:

if (problem size > default size) {
    tasks = Divide(task);
    execute(tasks);
    groupResults();
    return results;
} else {
    resolve problem;
    return result;
}

以下几个ForkJoinTask类中的方法我们需要了解:

  1. fork():无参数、返回值。此方法用于向线程池异步的发送一个任务,发送完成后将会立刻返回并向下执行
  2. get():一直等待直到获得任务返回的结果。另一种实现为get(long timeout, TimeUnit unit),如果等待时间超时后任务还未返回结果,则方法直接返回null。get方法可以被中断。如果任务抛出运行时异常,get方法会返回ExecutionException异常
  3. join():一直等待直到获得任务返回的结果。此方法和get()方法有些类似,区别在于join()方法不能被中断。如果中断调用了该方法的线程,join()方法将抛出InterruptedException异常。另外,任务抛出运行时异常时,join()方法会返回RuntimeWxception异常
    以上三个方法中,第一个与第二或三个方法组合经常用来实现异步运行任务这一需求

范例实现

在这个范例中,我们将统计一个指定词汇在文档中出现的次数。我们会不断切割任务直到每个任务仅搜索100个以内的词汇
DocumentMock(文档生成类):

package day06.code_2;


import java.util.Random;

public class DocumentMock {

    //从以下词汇中选择词语组成文档
    private String words[] = {
            "the", "hello", "goodbye", "packt", "java",
            "thread", "pool", "random", "class", "main"
    };

    public String[][] generateDocument(int numLines, int numWords, String word) {
        //记录指定词汇出现的次数,便于后期判断程序对错
        int counter = 0;
        //创建二维数组
        String[][] document = new String[numLines][numWords];
        //随机数生成器
        Random random = new Random();
        //填充数组
        for (int i = 0; i < numLines; i++) {
            for (int j = 0; j < numWords; j++) {
                //随机选取词汇并填充
                int index = random.nextInt(words.length);
                document[i][j] = words[index];
                //如果是指定词汇,计数器加一
                if (document[i][j] == word) {
                    counter++;
                }

            }
        }
        //打印指定词汇出现的次数
        System.out.printf("DocumentMock: The word appears " +
                "%d times in the document\n", counter);
        //返回文档
        return document;
    }

}

DocumentTask(文档任务类):

package day06.code_2;

import java.util.ArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;

public class DocumentTask extends RecursiveTask<Integer> {

    //必备参数
    private static final long serialVersionUID = 1L;

    //文档
    private String[][] document;

    //起始、结束位置
    private int start, end;

    //待查找的词汇
    private String word;

    public DocumentTask(String[][] document, int start, int end, String word) {
        this.document = document;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    @Override
    protected Integer compute() {
        //初始化计数器
        int result = 0;
        //如果行数小于10
        if (end - start < 10) {
            //处理每一行的数据
            result = processLines(document, start, end, word);
        } else {
            //行数大于10则进行任务分割
            int mid = (start + end) / 2;
            DocumentTask task1 = new DocumentTask(document, start, mid, word);
            DocumentTask task2 = new DocumentTask(document, mid, end, word);
            //提交任务(同步)
            invokeAll(task1, task2);
            try {
                //处理子任务返回的结果
                result = groupResults(task1.get(), task2.get());
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace();
            }
        }
        //返回结果
        return result;
    }

    //将子任务结果相加后返回
    private int groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }

    private int processLines(String[][] document, int start, int end, String word) {
        //创建装载行任务的集合
        ArrayList<LineTask> tasks = new ArrayList<>();
        //创建行任务
        for (int i = start; i < end; i++) {
            LineTask task = new LineTask(document[i], 0, document[i].length, word);
            tasks.add(task);
        }
        //执行所有任务
        invokeAll(tasks);
        //初始化计数器
        int result = 0;
        //从任务中获取结果
        for (int i = 0; i < tasks.size(); i++) {
            LineTask task = tasks.get(i);
            try {
                result = result + task.get();
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace();
            }
        }
        //返回结果
        return result;
    }

}

LineTask(单行任务类):

package day06.code_2;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;

public class LineTask extends RecursiveTask<Integer> {

    //必备参数
    private static final long serialVersionUID = 1L;

    //行数据
    private String line[];

    //起始、结束位置
    private int start, end;

    //待查找的词汇
    private String word;

    public LineTask(String[] line, int start, int end, String word) {
        this.line = line;
        this.start = start;
        this.end = end;
        this.word = word;
    }

    @Override
    protected Integer compute() {
        //初始化计数器
        int result = 0;
        //如果一行的数据小于100
        if (end - start < 100) {
            //查找指定词汇的数量
            result = count(line, start, end, word);
        } else {
            //分割任务
            int mid = (start + end) / 2;
            LineTask task1 = new LineTask(line, start, mid, word);
            LineTask task2 = new LineTask(line, mid, end, word);
            //执行
            invokeAll(task1, task2);
            //获取子任务的结果
            try {
                result = groupResults(task1.get(), task2.get());
            } catch (ExecutionException | InterruptedException e) {
                e.printStackTrace();
            }
        }
        return result;
    }

    //将子任务结果相加后返回
    private Integer groupResults(Integer number1, Integer number2) {
        return number1 + number2;
    }

    private int count(String[] line, int start, int end, String word) {
        //初始化计数器
        int counter = 0;
        //查找每一个元素是否为指定的词汇
        for (int i = start; i < end; i++) {
            if (line[i].equals(word)) {
                counter++;
            }
        }
        //休眠10毫秒
        try {
            Thread.sleep(10);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //返回结果
        return counter;
    }
}

main方法:

package day06.code_2;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //创建文档生成器
        DocumentMock mock = new DocumentMock();
        //生成文档
        String[][] document = mock.generateDocument(100, 1000, "the");
        //创建文档搜素任务
        DocumentTask task = new DocumentTask(document, 0, 100, "the");
        //创建线程池
        ForkJoinPool pool = new ForkJoinPool();
        //异步执行文档搜索任务
        pool.execute(task);
        //每隔一秒打印一次线程池的状态直到任务执行结束
        do {
            System.out.println("****************************************");
            //并行级别
            System.out.printf("Main: Parallelism: %d\n",
                    pool.getParallelism());
            //正在工作的线程
            System.out.printf("Main: Active Threads: %d\n",
                    pool.getActiveThreadCount());
            //已提交的任务数量(不包括尚未执行的)
            System.out.printf("Main: Task Count: %d\n",
                    pool.getQueuedTaskCount());
            //窃取工作的数量
            System.out.printf("Main: Steal Count: %d\n",
                    pool.getStealCount());
            System.out.println("****************************************");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while (!task.isDone());
        //关闭线程池
        pool.shutdown();
        //打印待查找关键词的数量
        try {
            System.out.printf("Main: The word appears %d in the document",
                    task.get());
        } catch (ExecutionException | InterruptedException e) {
            e.printStackTrace();
        }
    }

}

3.异步运行任务

当我们采用异步的方式向线程池发送任务时,方法将立即返回,代码也将继续向下执行,不过我们提交的任务会继续执行。在第二小节中我们已经将异步运行任务的相关方法记录了,就不在此赘述

范例实现

在这个范例中我们将查找指定的文件夹内是否有我们要查找的文件
FolderProcessor类(文件查找任务类):

package day06.code_3;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.RecursiveTask;

public class FolderProcessor extends RecursiveTask<List<String>> {

    //必备参数
    private static final long serialVersionUID = 1L;

    //文件夹路径
    private String path;

    //文件后缀名
    private String extension;

    public FolderProcessor(String path, String extension) {
        this.path = path;
        this.extension = extension;
    }

    @Override
    protected List<String> compute() {
        //创建一个集合用于装载文件路径
        ArrayList<String> list = new ArrayList<>();
        //创建集合用于装载任务
        ArrayList<FolderProcessor> tasks = new ArrayList<>();
        //创建文件对象
        File file = new File(path);
        //得到文件夹下的全部文件
        File[] content = file.listFiles();
        //判断是否为空
        if (content != null) {
            //遍历集合
            for (int i = 0; i < content.length; i++) {
                //如果是文件夹就创建任务继续查找
                if (content[i].isDirectory()) {
                    FolderProcessor task = new FolderProcessor
                            (content[i].getAbsolutePath(), extension);
                    //异步执行任务
                    task.fork();
                    //将任务保存进集合
                    tasks.add(task);
                } else {
                    //检查文件是否符合要求,符合的话就装入集合
                    if (checkFile(content[i].getName())) {
                        list.add(content[i].getAbsolutePath());
                    }
                }
            }
        }
        //如果文件集合容量超过50了就打印
        if (tasks.size() > 50) {
            System.out.printf("%s: %d tasks run\n",
                    file.getAbsolutePath(), tasks.size());
        }
        //整合子任务返回的结果
        addResultsFromTasks(list, tasks);
        //返回结果
        return list;
    }

    private void addResultsFromTasks(List<String> list, List<FolderProcessor> tasks) {
        //遍历任务集合
        for (FolderProcessor item : tasks) {
            //取得所有子任务返回的结果并装进集合中
            list.addAll(item.join());
        }
    }

    //检查文件后缀名是否符合要求
    private boolean checkFile(String name) {
        return name.endsWith(extension);
    }
}

main方法:

package day06.code_3;

import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //创建线程池
        ForkJoinPool pool = new ForkJoinPool();
        //创建三个任务并异步执行
        FolderProcessor system = new FolderProcessor("C:\\", "exe");
        FolderProcessor program = new FolderProcessor("D:\\", "exe");
        FolderProcessor data = new FolderProcessor("F:\\", "exe");
        pool.execute(system);
        pool.execute(program);
        pool.execute(data);
        //在任务没有都结束之前不断循环打印线程池的信息
        do {
            System.out.println("***************************************");
            System.out.printf("Main: Parallelism: %d\n",
                    pool.getParallelism());
            System.out.printf("Main: Active Threads: %d\n",
                    pool.getActiveThreadCount());
            System.out.printf("Main: Task Count: %d\n",
                    pool.getQueuedTaskCount());
            System.out.printf("Main: Steal Count: %d\n",
                    pool.getStealCount());
            System.out.println("***************************************");
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } while ((!system.isDone()) || (!program.isDone()) || (!data.isDone()));
        //关闭线程池
        pool.shutdown();
        //获取并打印每一个任务返回的结果
        List<String> result;
        result = system.join();
        System.out.printf("System: %d files found\n", result.size());
        result = program.join();
        System.out.printf("Program: %d files found\n", result.size());
        result = data.join();
        System.out.printf("Data: %d files found\n", result.size());
    }

}

4.在任务中抛出异常

ForkJoinTask类的compute()方法中不允许抛出非运行时异常,但是我们仍可以抛出运行时异常。然而,当任务抛出运行时异常时,ForkJoinPoolForkJoinTask类的行为和我们期待的并不相同。程序不会结束运行,异常信息也不会打印出来。只有当我们去获取任务的结果时,异常才会抛出。需要注意的是,当子任务抛出异常时,它的父任务也会受到影响。以下ForkJoinTask类中的几个方法会对我们获取异常信息有一定帮助:

  1. isCompletedAbnormally():如果主任务或它的子任务抛出了异常,此方法将返回true
  2. isCompletedNormally():如果主任务及它的子任务均正常完成了,此方法返回true
  3. getException():调用此方法来获得任务抛出的异常对象

范例实现

在这个范例中,我们将对一个数组进行搜索。搜索任务中如果包含了索引3,则抛出运行时异常
Task类:

package day06.code_4;

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class Task extends RecursiveTask<Integer> {

    //数组
    private int[] array;

    //起始、终止位置
    private int start, end;

    public Task(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        //打印搜索范围的信息
        System.out.printf("Task: Start from %d to %d\n",
                start, end);
        //如果搜索范围小于10
        if (end - start < 10) {
            //判断是否包含索引三
            if ((3 > start) && (3 < end)) {
                //抛出运行时异常
                throw new RuntimeException("This task throws an Exception: " +
                        "Task from " + start + " to " + end);
            }
            //休眠1秒
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        } else {
            //分割任务
            int mid = (start + end) / 2;
            Task task1 = new Task(array, start, mid);
            Task task2 = new Task(array, mid, end);
            //执行
            invokeAll(task1, task2);
        }
        //打印任务结束语
        System.out.printf("Task: End from %d to %d\n", start, end);
        return 0;

    }
}

main方法:

package day06.code_4;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //创建数组
        int[] array = new int[100];
        //创建任务
        Task task = new Task(array, 0, 100);
        //创建线程池
        ForkJoinPool pool = new ForkJoinPool();
        //执行任务
        pool.execute(task);
        //关闭线程池
        pool.shutdown();
        //休眠,直至线程池中的任务全部完成
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //判断任务是否存在异常
        if (task.isCompletedAbnormally()) {
            //打印异常提示语
            System.out.printf("Main: An exception has ocurred\n");
            //打印获取到的异常对象
            System.out.printf("Main: %s\n", task.getException());
        }
        //打印任务结果
        System.out.printf("Main: Result: %d", task.join());
    }
}

5.取消任务

ForkJoinTask类提供了cancel(boolean mayInterruptIfRunning)方法来达到取消任务的目的。和之前我们用到过的FutureTask类不同的是,ForkJoinTask类的cancel()方法只能取消未被执行的任务。JavaAPI文档指出,在ForkJoinTask类的默认实现中,传入的参数并没有起到作用,这就导致已经开始执行和已经执行结束的任务都不能被取消。取消成功返回true,否则返回false。另外,ForkJoinPool类中并没有提供任务用于取消任务的方法。

范例实现

在这个范例中,我们将在数组中寻找一个数字,找到后就取消其他的搜索任务。
ArrayGenerator(数组生成类):

package day06.code_5;

import java.util.Random;

public class ArrayGenerator {

    public int[] generateArray(int size) {
        //根据传入的参数生成一个数组
        int[] array = new int[size];
        //创建随机数生成器对象
        Random random = new Random();
        //对数组进行初始化
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(10);
        }
        //返回数组
        return array;
    }

}

TaskManager(任务管理类,该类将帮助我们取消其他任务):

package day06.code_5;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinTask;

public class TaskManager {

    //任务集合
    private List<ForkJoinTask<Integer>> tasks;

    public TaskManager() {
        tasks = new ArrayList<>();
    }

    //向集合中添加任务
    public void addTask(ForkJoinTask<Integer> task) {
        tasks.add(task);
    }

    public void cancelTasks(ForkJoinTask<Integer> cancelTask) {
        //取消除传入的任务以外的其他所有任务
        for (ForkJoinTask<Integer> task : tasks) {
            if (task != cancelTask) {
                //取消任务
                task.cancel(true);
                //打印取消信息
                ((SearchNumberTask) task).writeCancelMessage();
            }
        }
    }
}

SearchNumberTask(搜索数字任务类):

package day06.code_5;


import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class SearchNumberTask extends RecursiveTask<Integer> {

    //待搜索的数组
    private int[] numbers;

    //搜索范围
    private int start, end;

    //目标数字
    private int number;

    //任务管理器
    private TaskManager manager;

    //未查询到目标数字时返回的常量
    private static final int NOT_FOUND = -1;

    //必要参数
    private static final long serialVersionUID = 1L;

    public SearchNumberTask(int[] numbers, int start, int end,
                            int number, TaskManager manager) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
        this.number = number;
        this.manager = manager;
    }

    @Override
    protected Integer compute() {
        //打印任务开始提示信息
        System.out.printf("Task: %d : %d\n", start, end);
        int ret;
        //如果搜索范围大于10
        if (end - start > 10) {
            //调用切割任务的方法
            ret = launchTasks();
        } else {
            //查找目标数字
            ret = lookForNumber();
        }
        //返回结果
        return ret;
    }

    private int launchTasks() {
        //切割任务
        int mid = (start + end) / 2;
        //创建两个新的任务在将其加入任务集合后执行
        SearchNumberTask task1 = new SearchNumberTask(numbers, start, mid, number, manager);
        SearchNumberTask task2 = new SearchNumberTask(numbers, mid, end, number, manager);
        manager.addTask(task1);
        manager.addTask(task2);
        task1.fork();
        task2.fork();
        //返回值
        int returnValue;
        //获取任务1的结果
        returnValue = task1.join();
        //如果查询到了就返回索引
        if (returnValue != -1) {
            return returnValue;
        }
        //否则返回任务2的结果
        return task2.join();

    }

    private int lookForNumber() {
        //遍历搜索范围内的数组
        for (int i = start; i < end; i++) {
            //如果是目标数字
            if (numbers[i] == number) {
                //打印查找成功提示语
                System.out.printf("Task: Number %d found in position %d\n",
                        number, i);
                //调用任务管理器的方法取消其他任务
                manager.cancelTasks(this);
                //返回目标数字的索引
                return i;
            }
            //休眠1秒
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        //没有查询到,返回常量
        return NOT_FOUND;
    }

    public void writeCancelMessage() {
        //打印任务取消的提示信息
        System.out.printf("Task: Cancelled task from %d to %d\n",
                start, end);
    }
}

main方法:

package day06.code_5;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

public class Main {

    public static void main(String[] args) {
        //创建数组生成器
        ArrayGenerator generator = new ArrayGenerator();
        //得到一个容量为1000的数组
        int[] array = generator.generateArray(1000);
        //创建任务管理器
        TaskManager manager = new TaskManager();
        //创建线程池
        ForkJoinPool pool = new ForkJoinPool();
        //创建搜素数字任务
        SearchNumberTask task = new SearchNumberTask
                (array, 0, 1000, 5, manager);
        //将任务发送给线程池执行
        pool.execute(task);
        //关闭线程池
        pool.shutdown();
        //等待线程池将所有未取消的任务执行完毕
        try {
            pool.awaitTermination(1, TimeUnit.DAYS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        //打印程序结束信息
        System.out.println("Main: The program has finished");
    }

}