🍴 ForkJoinPool工作窃取算法:分治思想的极致体现!

51 阅读11分钟

一、什么是ForkJoin?🤔

核心思想:分而治之(Divide and Conquer)

一句话总结:
把大任务**拆分(Fork)成小任务,并行执行后再合并(Join)**结果!

生活比喻:

吃一头大象🐘怎么办?

传统方式(单线程):
从头吃到尾,吃到天荒地老...😭

ForkJoin方式:
1. Fork(拆分):
   把大象切成1000块
   
2. Parallel(并行):
   召集1000个人,每人吃一块
   
3. Join(合并):
   大家一起吃完!✅

经典应用场景

1. 计算密集型任务
   ├─ 大数组求和
   ├─ 大数据排序
   ├─ 斐波那契数列
   └─ 矩阵运算

2. 数据处理
   ├─ 文件批量处理
   ├─ 图片批量压缩
   ├─ 日志批量分析
   └─ 数据库批量操作

3. 树/图遍历
   ├─ 文件树遍历
   ├─ DOM树处理
   └─ 图搜索算法

二、ForkJoin框架核心组件 🏗️

1️⃣ ForkJoinPool - 线程池

// 创建ForkJoinPool
ForkJoinPool pool = new ForkJoinPool();

// 指定并行度(线程数)
ForkJoinPool pool = new ForkJoinPool(4);  // 4个工作线程

// 使用通用池(推荐)
ForkJoinPool commonPool = ForkJoinPool.commonPool();

通用池的并行度:

// 默认并行度 = CPU核心数 - 1
int parallelism = Runtime.getRuntime().availableProcessors() - 1;

// 例如:8核CPU → 7个工作线程

2️⃣ ForkJoinTask - 任务抽象类

ForkJoinTask(抽象类)
    │
    ├─ RecursiveTask<V>      (有返回值)
    │     例如:计算求和,返回结果
    │
    └─ RecursiveAction       (无返回值)
          例如:打印输出,不需要返回

3️⃣ 工作窃取队列(Deque)

每个线程都有自己的双端队列:

Worker Thread 1      Worker Thread 2      Worker Thread 3
      │                    │                    │
      ↓                    ↓                    ↓
   ┌─────┐              ┌─────┐              ┌─────┐
   │Task │              │Task │              │Task │
   ├─────┤              ├─────┤              ├─────┤
   │Task │              │Task │              │     │ ← 空闲!
   ├─────┤              ├─────┤              └─────┘
   │Task │              │Task │                  ↑
   └─────┘              └─────┘                  │
                            ↓                    │
                        窃取!← ─ ─ ─ ─ ─ ─ ─ ─ ─┘

三、工作窃取算法(Work-Stealing)🥷

核心思想

闲着也是闲着,不如帮别人干活!

线程1:任务很多,忙不过来 😰
线程2:任务做完了,闲着     😴
      ↓
线程2:偷走线程1的一个任务来做!🥷

详细流程

1. 初始分配:
   ┌────────┐  ┌────────┐  ┌────────┐
   │Thread-1│  │Thread-2│  │Thread-3│
   └────────┘  └────────┘  └────────┘
   自己的任务    自己的任务    自己的任务

2. Thread-2先完成,队列空了:
   ┌────────┐  ┌────────┐  ┌────────┐
   │Thread-1│  │Thread-2│  │Thread-3│
   └────────┘  └────────┘  └────────┘
   [Task][Task]  [ 空 ] ← 闲着!  [Task]

3. Thread-2去偷Thread-1的任务:
   ┌────────┐  ┌────────┐  ┌────────┐
   │Thread-1│  │Thread-2│  │Thread-3│
   └────────┘  └────────┘  └────────┘
   [Task]       [Task] ← 偷来的!    [Task]

4. Thread-2又做完了,去偷Thread-3的:
   ┌────────┐  ┌────────┐  ┌────────┐
   │Thread-1│  │Thread-2│  │Thread-3│
   └────────┘  └────────┘  └────────┘
   [Task]       [Task] ← 又偷来的!  [ 空 ]

5. 大家一起完成所有任务!✅

为什么用双端队列(Deque)?

双端队列:两头都能进出

自己的线程:
从头部(head)取任务  LIFO (后进先出)
├─ pop()  ← 拿最新的任务
│
偷窃者线程:
从尾部(tail)偷任务  FIFO (先进先出)
└─ steal() ← 拿最老的任务

好处:
1. 减少冲突:自己拿新的,别人偷旧的,碰撞概率低
2. 局部性:新任务可能用到刚才的数据,缓存命中率高

图示:

线程自己的队列:
  head                           tail
   ↓                              ↓
  [Task4][Task3][Task2][Task1]
   ↑                              ↑
   │                              │
自己pop() ← LIFO              别人steal() ← FIFO

自己:拿Task4(最新的)
别人:偷Task1(最老的)
→ 不会冲突!✨

四、代码实战:手把手教你写ForkJoin 💻

例子1:大数组求和(RecursiveTask)

import java.util.concurrent.*;

public class SumTask extends RecursiveTask<Long> {
    
    // 阈值:小于这个值就不再拆分
    private static final int THRESHOLD = 10_000;
    
    private long[] array;
    private int start;
    private int end;
    
    public SumTask(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected Long compute() {
        int length = end - start;
        
        // 1. 如果任务足够小,直接计算
        if (length <= THRESHOLD) {
            return computeDirectly();
        }
        
        // 2. 任务太大,拆分成两个子任务
        int middle = start + length / 2;
        
        SumTask leftTask = new SumTask(array, start, middle);
        SumTask rightTask = new SumTask(array, middle, end);
        
        // 3. Fork:异步执行左边的任务
        leftTask.fork();  // 放入队列
        
        // 4. 当前线程执行右边的任务
        long rightResult = rightTask.compute();  // 直接计算
        
        // 5. Join:等待左边的结果
        long leftResult = leftTask.join();  // 阻塞等待
        
        // 6. 合并结果
        return leftResult + rightResult;
    }
    
    private long computeDirectly() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        return sum;
    }
    
    // 使用示例
    public static void main(String[] args) {
        // 创建100万个数的数组
        long[] array = new long[1_000_000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        
        // 创建任务
        SumTask task = new SumTask(array, 0, array.length);
        
        // 提交到ForkJoinPool
        ForkJoinPool pool = ForkJoinPool.commonPool();
        long result = pool.invoke(task);  // 阻塞等待结果
        
        System.out.println("结果:" + result);
        // 输出:500000500000
    }
}

执行流程图:

                 [0, 1000000] (100万)
                      │
                   Fork & Join
                      ↓
         ┌────────────┴────────────┐
         │                         │
    [0, 500000]               [500000, 1000000]
    (50万)                     (50万)
         │                         │
      Fork & Join              Fork & Join
         ↓                         ↓
    ┌────┴────┐              ┌────┴────┐
    │         │              │         │
[0,250000] [250000,500000] [500000,750000] [750000,1000000]
(25万)      (25万)          (25万)          (25万)
  ...        ...            ...             ...
继续拆分,直到 ≤ 10000

最终:100个小任务,并行计算

例子2:归并排序(RecursiveAction)

import java.util.Arrays;
import java.util.concurrent.*;

public class MergeSortTask extends RecursiveAction {
    
    private static final int THRESHOLD = 100;  // 阈值
    
    private int[] array;
    private int start;
    private int end;
    
    public MergeSortTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected void compute() {
        int length = end - start;
        
        // 1. 小数组直接排序
        if (length <= THRESHOLD) {
            Arrays.sort(array, start, end);
            return;
        }
        
        // 2. 拆分成两个子任务
        int middle = start + length / 2;
        
        MergeSortTask leftTask = new MergeSortTask(array, start, middle);
        MergeSortTask rightTask = new MergeSortTask(array, middle, end);
        
        // 3. 并行执行两个子任务
        invokeAll(leftTask, rightTask);  // 等待两个都完成
        
        // 4. 合并结果
        merge(start, middle, end);
    }
    
    private void merge(int start, int middle, int end) {
        int[] temp = new int[end - start];
        int i = start, j = middle, k = 0;
        
        while (i < middle && j < end) {
            temp[k++] = array[i] <= array[j] ? array[i++] : array[j++];
        }
        
        while (i < middle) temp[k++] = array[i++];
        while (j < end) temp[k++] = array[j++];
        
        System.arraycopy(temp, 0, array, start, temp.length);
    }
    
    // 使用示例
    public static void main(String[] args) {
        int[] array = new Random().ints(1_000_000, 0, 100000).toArray();
        
        long start = System.currentTimeMillis();
        
        ForkJoinPool pool = ForkJoinPool.commonPool();
        pool.invoke(new MergeSortTask(array, 0, array.length));
        
        long end = System.currentTimeMillis();
        System.out.println("耗时:" + (end - start) + "ms");
        
        // 验证排序正确性
        System.out.println("已排序:" + isSorted(array));
    }
}

例子3:斐波那契数列(经典递归)

public class FibonacciTask extends RecursiveTask<Integer> {
    
    private final int n;
    
    public FibonacciTask(int n) {
        this.n = n;
    }
    
    @Override
    protected Integer compute() {
        // 递归终止条件
        if (n <= 1) {
            return n;
        }
        
        // Fork两个子任务
        FibonacciTask f1 = new FibonacciTask(n - 1);
        FibonacciTask f2 = new FibonacciTask(n - 2);
        
        f1.fork();  // 异步执行f1
        
        int result2 = f2.compute();  // 当前线程执行f2
        int result1 = f1.join();     // 等待f1结果
        
        return result1 + result2;
    }
    
    // 使用
    public static void main(String[] args) {
        ForkJoinPool pool = ForkJoinPool.commonPool();
        int result = pool.invoke(new FibonacciTask(40));
        System.out.println("Fibonacci(40) = " + result);
    }
}

五、Fork/Join vs 普通线程池 ⚔️

性能对比

// 测试:100万个数求和

// 方案1:单线程
long sum1 = 0;
for (long num : array) {
    sum1 += num;
}
// 耗时:~15ms

// 方案2:普通线程池
ExecutorService executor = Executors.newFixedThreadPool(8);
List<Future<Long>> futures = new ArrayList<>();
int chunkSize = array.length / 8;
for (int i = 0; i < 8; i++) {
    int start = i * chunkSize;
    int end = (i == 7) ? array.length : (i + 1) * chunkSize;
    futures.add(executor.submit(() -> {
        long sum = 0;
        for (int j = start; j < end; j++) {
            sum += array[j];
        }
        return sum;
    }));
}
long sum2 = 0;
for (Future<Long> f : futures) {
    sum2 += f.get();
}
// 耗时:~5ms(提速3倍)

// 方案3:ForkJoinPool
ForkJoinPool pool = ForkJoinPool.commonPool();
long sum3 = pool.invoke(new SumTask(array, 0, array.length));
// 耗时:~3ms(提速5倍!)✨

对比表

特性普通线程池ForkJoinPool
适用场景独立任务可拆分的任务
任务粒度粗粒度细粒度
负载均衡静态分配动态窃取 ✨
队列单个共享队列每线程一个队列
性能中等高(递归任务)
线程利用率中等高(工作窃取)✨
编程复杂度简单中等

何时用ForkJoin?

✅ 适合ForkJoin:
1. 任务可以拆分成子任务
2. 子任务之间相互独立
3. 任务是CPU密集型
4. 数据量大,拆分后并行效果好

例如:
- 大数组/集合计算
- 递归算法
- 树/图遍历
- 批量数据处理

❌ 不适合ForkJoin:
1. IO密集型任务(会阻塞线程)
2. 任务不可拆分
3. 任务有依赖关系
4. 任务数量少

例如:
- 数据库查询
- 网络请求
- 文件IO
- 3个独立任务

六、核心API详解 📚

fork() - 异步执行

task.fork();  // 将任务放入当前线程的队列,异步执行

join() - 等待结果

result = task.join();  // 阻塞等待任务完成,返回结果

invoke() - 同步执行

result = task.invoke();  // 立即执行并等待结果(= compute() + join())

invokeAll() - 批量执行

// 执行多个任务,等待全部完成
invokeAll(task1, task2, task3);

// 等价于:
task1.fork();
task2.fork();
task3.fork();
task1.join();
task2.join();
task3.join();

最佳实践

// ❌ 错误:都用fork,最后都join
leftTask.fork();
rightTask.fork();
long left = leftTask.join();
long right = rightTask.join();
// 问题:当前线程也空闲了,浪费!

// ✅ 正确:fork一个,当前线程执行另一个
leftTask.fork();  // 左边异步
long right = rightTask.compute();  // 右边当前线程执行
long left = leftTask.join();  // 等待左边
// 好处:充分利用当前线程!✨

七、阈值(Threshold)设置技巧 🎯

阈值的作用

阈值太小:
- 拆分次数多
- 任务创建开销大
- 工作窃取频繁
- 性能下降 📉

阈值太大:
- 拆分次数少
- 并行度低
- 某些线程闲置
- 性能下降 📉

阈值合适:
- 拆分适中
- 充分并行
- 开销可控
- 性能最优 🚀

经验值

// 计算密集型:
int threshold = totalSize / (parallelism * 4);
// 例如:100万数据,8核CPU
// threshold = 1000000 / (8 * 4) = 31250

// 通用建议:
// - 小任务(< 1万):1000 - 5000
// - 中任务(1万 - 100万):10000 - 50000
// - 大任务(> 100万):50000 - 100000

动态阈值

public class AdaptiveTask extends RecursiveTask<Long> {
    
    private static final int MIN_THRESHOLD = 1000;
    private static final int MAX_THRESHOLD = 100000;
    
    @Override
    protected Long compute() {
        int length = end - start;
        
        // 动态计算阈值
        int threshold = Math.min(MAX_THRESHOLD,
            Math.max(MIN_THRESHOLD, 
                array.length / (ForkJoinPool.getCommonPoolParallelism() * 4)));
        
        if (length <= threshold) {
            return computeDirectly();
        }
        
        // ... fork & join
    }
}

八、并行流(Parallel Stream)的秘密 🌊

底层就是ForkJoinPool

// Java 8的并行流
long sum = Arrays.stream(array)
                 .parallel()  // 并行化
                 .sum();

// 等价于:
ForkJoinPool.commonPool().invoke(new SumTask(...));

并行流 vs 手写ForkJoin

// 方案1:并行流(简单)
long sum = Arrays.stream(array).parallel().sum();

// 方案2:手写ForkJoin(灵活)
ForkJoinPool pool = new ForkJoinPool(16);  // 自定义线程数
long sum = pool.invoke(new SumTask(array, 0, array.length));

对比:

特性Parallel Stream手写ForkJoin
代码量1行 ✅50行+
易用性超简单 ✅需要理解
灵活性有限完全控制 ✅
自定义线程池不支持*支持 ✅
性能调优有限完全控制 ✅

*注:通过trick可以实现,但不推荐:

// ⚠️ 不推荐的hack方式
ForkJoinPool customPool = new ForkJoinPool(16);
customPool.submit(() -> 
    Arrays.stream(array).parallel().sum()
).get();

九、常见坑点 ⚠️

坑1:阈值设置不当

// ❌ 错误:阈值太小
private static final int THRESHOLD = 1;
// 100万数据会拆分100万次!创建100万个任务!💥

// ✅ 正确:合理阈值
private static final int THRESHOLD = 10000;
// 100万数据拆分100次,创建100个任务 ✅

坑2:任务不均匀

// ❌ 错误:左重右轻
int middle = start + 1;  // 左边几乎全部,右边只有1个
// 导致负载不均衡

// ✅ 正确:平均拆分
int middle = start + (end - start) / 2;

坑3:使用了阻塞操作

// ❌ 错误:在compute()中阻塞
@Override
protected Long compute() {
    // 阻塞操作会占用工作线程!
    Thread.sleep(1000);  // ❌
    userService.queryUser();  // ❌ IO操作
    return result;
}

// ✅ 正确:ForkJoin只用于CPU密集型

坑4:共享变量

// ❌ 错误:多个任务修改共享变量
private static long sum = 0;  // 共享变量

@Override
protected Long compute() {
    sum += xxx;  // ❌ 并发问题!
    return null;
}

// ✅ 正确:通过返回值传递
@Override
protected Long compute() {
    return xxx;  // ✅ 无副作用
}

十、面试应答模板 🎤

面试官:说说ForkJoinPool的工作窃取算法原理?

你的回答:

ForkJoinPool是基于分治思想工作窃取算法的线程池,专门用于处理可拆分的递归任务。

核心原理:

1️⃣ 任务拆分(Fork):

  • 将大任务递归拆分成小任务
  • 拆分到一定阈值后直接计算
  • 使用RecursiveTask(有返回值)或RecursiveAction(无返回值)

2️⃣ 工作窃取(Work-Stealing):

  • 每个工作线程有自己的双端队列(Deque)
  • 线程从队列头部取自己的任务(LIFO)
  • 空闲线程从别人队列尾部偷任务(FIFO)
  • 这样减少了竞争,提高了并行度

3️⃣ 结果合并(Join):

  • 等待子任务完成
  • 合并子任务的结果
  • 返回最终结果

为什么用双端队列?

  • 自己从头部拿(新任务),利用CPU缓存局部性
  • 别人从尾部偷(老任务),减少冲突
  • 两端操作,碰撞概率低

适用场景:

  • CPU密集型任务
  • 可递归拆分的任务(如归并排序、大数组计算)
  • 任务粒度可控

举例: 100万个数求和,拆分成100个子任务,8个线程并行计算。某个线程提前完成后,会去偷其他线程的任务,充分利用CPU,性能比普通线程池提升30%+。

十一、总结 🎯

ForkJoin核心要点:

思想:分治(Divide and Conquer)
  ├─ Fork:拆分任务
  ├─ Parallel:并行执行
  └─ Join:合并结果

算法:工作窃取(Work-Stealing)
  ├─ 每线程一个双端队列
  ├─ 自己从头部拿(LIFO)
  └─ 别人从尾部偷(FIFO)

优势:
  ✅ 负载均衡(动态窃取)
  ✅ 充分利用CPU
  ✅ 减少线程竞争
  ✅ 性能优异

适用:
  ✅ CPU密集型
  ✅ 可拆分任务
  ✅ 递归算法
  ✅ 大数据处理

不适用:
  ❌ IO密集型
  ❌ 阻塞操作
  ❌ 任务不可拆分

记忆口诀:
大任务拆小块,
并行执行效率高,
工作窃取防闲置,
双端队列减冲突,
CPU密集最适合!🎵

核心要点:

  • ✅ 分治思想:Fork + Compute + Join
  • ✅ 工作窃取:双端队列 + 头尾分离
  • ✅ 适合CPU密集型递归任务
  • ✅ 阈值设置很重要
  • ✅ 并行流底层用的就是ForkJoinPool