Java Source Code - java.util.concurrent.RecursiveTask

Introduction

RecursiveTask is a subclass of ForkJoinTask intended for computations that return a result.
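
Here is a minimal sketch of what "returns a result" means in practice. The subclass implements compute(), and the caller gets that return value back from invoke() or join(). The names MaxTask and MaxTaskDemo and the sample data are hypothetical. The task deliberately does no splitting.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical task: finds the largest element of a non-empty int array.
class MaxTask extends RecursiveTask<Integer> {
    private final int[] data;

    MaxTask(int[] data) { this.data = data; }

    @Override
    protected Integer compute() {
        int max = data[0];
        for (int v : data) max = Math.max(max, v);
        return max; // this value is what invoke() / join() hand back
    }
}

class MaxTaskDemo {
    public static void main(String[] args) {
        int[] data = {3, 9, 2, 7};

        // invoke() runs the task and waits for its result
        int viaInvoke = ForkJoinPool.commonPool().invoke(new MaxTask(data));

        // fork() starts the task asynchronously (outside a pool it goes to
        // ForkJoinPool.commonPool()); join() later returns its result
        MaxTask t = new MaxTask(data);
        t.fork();
        int viaJoin = t.join();

        System.out.println(viaInvoke + " " + viaJoin); // 9 9
    }
}
```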

Source Code

The source is short, so it is reproduced directly below.

public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    /**
     * The result of the computation.
     */
    V result;

    /**
     * The main computation performed by this task.
     * @return the result of the computation
     */
    protected abstract V compute();

    public final V getRawResult() {
        return result;
    }

    protected final void setRawResult(V value) {
        result = value;
    }

    /**
     * Implements execution conventions for RecursiveTask.
     */
    protected final boolean exec() {
        result = compute();
        return true;
    }

}
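
The following sketch shows the fields above in action by reading getRawResult() before and after the task runs. The class names Answer and RawResultDemo are hypothetical.

- Before execution, the result field still holds its default value, null.
- invoke() goes through exec(), which stores the return value of compute() in result.
- Because exec() returns true, the task counts as completed normally.

```java
import java.util.concurrent.RecursiveTask;

// Hypothetical task used only to observe the `result` field from the source above.
class Answer extends RecursiveTask<Integer> {
    @Override
    protected Integer compute() { return 42; }
}

class RawResultDemo {
    public static void main(String[] args) {
        Answer task = new Answer();

        // not executed yet: `result` still has its default value
        System.out.println(task.getRawResult()); // null

        // invoke() runs exec(), which assigns result = compute()
        Integer value = task.invoke();

        // exec() returned true, so the task completed normally
        System.out.println(value + " " + task.getRawResult() + " "
                + task.isCompletedNormally()); // 42 42 true
    }
}
```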

Notes

After presenting a Fibonacci task as its classic example, the RecursiveTask Javadoc warns:

"However, besides being a dumb way to compute Fibonacci functions (there is a simple fast linear algorithm that you'd use in practice), this is likely to perform poorly because the smallest subtasks are too small to be worthwhile splitting up. Instead, as is the case for nearly all fork/join applications, you'd pick some minimum granularity size (for example 10 here) for which you always sequentially solve rather than subdividing."
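
Applying that advice, here is a sketch of a Fibonacci task with a cutoff. At or below the threshold, it computes the value sequentially instead of forking. Several details are assumptions for illustration:

- The class name ThresholdFibonacci is hypothetical.
- The sequential fallback is an iterative linear loop.
- The cutoff of 10 is taken from the quote, not tuned.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Fibonacci with a minimum granularity: at or below THRESHOLD the task stops
// splitting and solves the subproblem sequentially (assumed cutoff, not tuned).
class ThresholdFibonacci extends RecursiveTask<Long> {
    static final int THRESHOLD = 10;
    final int n;

    ThresholdFibonacci(int n) { this.n = n; }

    // Sequential linear-time Fibonacci for small subproblems.
    static long linearFib(int n) {
        long a = 0, b = 1;            // F(0), F(1)
        for (int i = 0; i < n; i++) {
            long next = a + b;
            a = b;
            b = next;
        }
        return a;                     // after n steps, a == F(n)
    }

    @Override
    protected Long compute() {
        if (n <= THRESHOLD)
            return linearFib(n);      // too small to be worth splitting
        ThresholdFibonacci f1 = new ThresholdFibonacci(n - 1);
        f1.fork();                    // F(n-1) runs asynchronously
        ThresholdFibonacci f2 = new ThresholdFibonacci(n - 2);
        return f2.compute() + f1.join(); // F(n-2) runs in this thread
    }

    public static void main(String[] args) {
        System.out.println(ForkJoinPool.commonPool().invoke(new ThresholdFibonacci(30))); // 832040
    }
}
```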

Examples

Fibonacci

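import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class Fibonacci extends RecursiveTask<Integer> {

    public static void main(String[] args) {
        // run the root task in a pool rather than calling compute() directly
        System.out.println(ForkJoinPool.commonPool().invoke(new Fibonacci(10))); // prints 55
    }

    final int n;
    Fibonacci(int n) { this.n = n; }

    @Override
    protected Integer compute() {
        if (n <= 1)
            return n;
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork();                        // compute F(n-1) asynchronously

        Fibonacci f2 = new Fibonacci(n - 2);
        // compute F(n-2) in the current thread instead of forking it as well
        return f2.compute() + f1.join();
    }
}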

ForkJoinPool

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolDemo {

    public static void main(String[] args) 
                          throws ExecutionException, InterruptedException {
        ForkJoinPool forkJoinPool = new ForkJoinPool();

        Task countTask = new Task(1, 100);
        ForkJoinTask<Integer> result = forkJoinPool.submit(countTask);

        System.out.println("Result is: " + result.get());
        forkJoinPool.shutdown();
    }

    static class Task extends RecursiveTask<Integer> {
        private final int start;  // inclusive lower bound
        private final int end;    // inclusive upper bound

        public Task(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            int sum = 0;
            if (end - start < 6) {
                // the range is small: sum it directly
                for (int i = start; i <= end; i++)
                    sum += i;
                System.out.println(Thread.currentThread().getName()
                                                 + " count sum: " + sum);
            } else {
                // split the range into two halves
                int mid = (end - start) / 2 + start;
                Task left = new Task(start, mid);
                Task right = new Task(mid + 1, end);

                // fork the left half and compute the right half in this thread
                left.fork();
                sum += right.compute();

                // wait for the forked half and add its sum
                sum += left.join();
            }
            return sum;
        }
    }
}
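
The demo above collects its result with submit() and get(). That is why main has to declare ExecutionException and InterruptedException. The alternative sketch below uses two other calls:

- ForkJoinPool.invoke() blocks until the task finishes and returns the value directly, with no checked exceptions.
- The static ForkJoinTask.invokeAll(t1, t2) runs two subtasks and returns once both are done.

The class name RangeSumTask is hypothetical, and the cutoff of 6 mirrors the demo.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical variant of the range-sum task: sums the integers in [start, end].
class RangeSumTask extends RecursiveTask<Integer> {
    private static final int THRESHOLD = 6; // same cutoff as the demo
    private final int start, end;           // inclusive bounds

    RangeSumTask(int start, int end) { this.start = start; this.end = end; }

    @Override
    protected Integer compute() {
        if (end - start < THRESHOLD) {
            int sum = 0;
            for (int i = start; i <= end; i++) sum += i;
            return sum;
        }
        int mid = (end - start) / 2 + start;
        RangeSumTask left = new RangeSumTask(start, mid);
        RangeSumTask right = new RangeSumTask(mid + 1, end);
        invokeAll(left, right);             // runs both halves, returns when both are done
        return left.join() + right.join();  // both finished, so join() does not block
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // invoke() returns the result directly; no checked exceptions to declare
            System.out.println("Result is: " + pool.invoke(new RangeSumTask(1, 100))); // Result is: 5050
        } finally {
            pool.shutdown();
        }
    }
}
```

Because invoke() rethrows any failure as an unchecked exception, the caller does not need the try/catch or throws clause that get() requires.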