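ForkJoinPool and the Work-Stealing Algorithm
ForkJoinPool, introduced in Java 7, is a thread pool designed for divide-and-conquer tasks and built on a work-stealing algorithm. This article takes an in-depth look at how it works and how to use it.
Core Design
The Work-Stealing Algorithm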
Thread A (busy)            Thread B (idle)
Deque: [T1, T2, T3]        Deque: []
       ↓                          ↓
  running T1               steals T3 from the tail of A's deque
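Characteristics:
- Each worker thread maintains its own double-ended queue (deque).
- A thread takes the tasks it runs from the head of its own deque.
- An idle thread steals tasks from the tail of another thread's deque.
- Because owners and thieves work at opposite ends (head vs. tail), contention between threads is reduced.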
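The head/tail split described above can be modeled in a few lines. This is only a toy sketch that follows the diagram: the class name `WorkStealingSketch` and its methods are made up for illustration, and `ConcurrentLinkedDeque` stands in for the pool's internal queues, which the JDK implements quite differently (array-based queues, and by default a worker processes its own tasks in LIFO order while thieves take the oldest task).

```java
import java.util.concurrent.ConcurrentLinkedDeque;

/** Toy model of one worker's deque; not how ForkJoinPool is implemented internally. */
final class WorkStealingSketch {
    private final ConcurrentLinkedDeque<String> deque = new ConcurrentLinkedDeque<>();

    /** Appends a task at the tail of the deque. */
    void enqueue(String task) {
        deque.addLast(task);
    }

    /** The owning thread takes its next task from the head; returns null when empty. */
    String takeOwn() {
        return deque.pollFirst();
    }

    /** An idle thread steals from the tail; returns null when empty. */
    String steal() {
        return deque.pollLast();
    }

    public static void main(String[] args) {
        WorkStealingSketch threadA = new WorkStealingSketch();
        threadA.enqueue("T1");
        threadA.enqueue("T2");
        threadA.enqueue("T3");
        System.out.println("A runs " + threadA.takeOwn());  // A runs T1
        System.out.println("B steals " + threadA.steal());  // B steals T3
        System.out.println("A runs " + threadA.takeOwn());  // A runs T2
    }
}
```

The owner and the thief only touch the same element when a single task is left, which is why the two ends rarely contend.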
ForkJoinTask
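RecursiveAction (no return value)
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class PrintTask extends RecursiveAction {
    private static final int THRESHOLD = 10;
    private final int start;
    private final int end;

    public PrintTask(int start, int end) {
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            // Small enough: do the work directly
            for (int i = start; i < end; i++) {
                System.out.println(Thread.currentThread().getName() + ": " + i);
            }
        } else {
            // Too large: split in half and run both halves
            int mid = start + (end - start) / 2; // avoids int overflow of (start + end)
            PrintTask left = new PrintTask(start, mid);
            PrintTask right = new PrintTask(mid, end);
            invokeAll(left, right);
        }
    }
}

// Usage
ForkJoinPool pool = new ForkJoinPool();
// invoke() blocks until the task has finished. With submit(), main could return first,
// and since pool workers are daemon threads the JVM might exit before all lines are printed.
pool.invoke(new PrintTask(0, 100));
pool.shutdown();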
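Printing is hard to verify, so here is a second `RecursiveAction` sketch whose side effect can be checked: it squares every element of an array in place. The class name `SquareAction`, the helper `squareAll`, and the threshold of 1,000 are illustrative choices, not something taken from the JDK.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

/** Squares each element of data[start, end) in place. */
final class SquareAction extends RecursiveAction {
    private static final int THRESHOLD = 1_000; // illustrative value, tune for real workloads
    private final long[] data;
    private final int start;
    private final int end;

    SquareAction(long[] data, int start, int end) {
        this.data = data;
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            for (int i = start; i < end; i++) {
                data[i] = data[i] * data[i];
            }
            return;
        }
        int mid = start + (end - start) / 2;
        // The two halves cover disjoint index ranges, so they never write the same slot.
        invokeAll(new SquareAction(data, start, mid), new SquareAction(data, mid, end));
    }

    /** Runs the action on the common pool and returns the same (now modified) array. */
    static long[] squareAll(long[] data) {
        ForkJoinPool.commonPool().invoke(new SquareAction(data, 0, data.length));
        return data;
    }

    public static void main(String[] args) {
        long[] data = {1, 2, 3, 4};
        System.out.println(Arrays.toString(squareAll(data))); // [1, 4, 9, 16]
    }
}
```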
RecursiveTask (returns a value)
import java.util.concurrent.RecursiveTask;

public class SumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 10000;
    private final long[] array;
    private final int start;
    private final int end;

    public SumTask(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // Small enough: sum the range directly
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        int mid = start + (end - start) / 2; // avoids int overflow of (start + end)
        SumTask left = new SumTask(array, start, mid);
        SumTask right = new SumTask(array, mid, end);
        left.fork();                      // schedule the left half asynchronously
        long rightSum = right.compute();  // compute the right half in the current thread
        long leftSum = left.join();       // wait for the left half
        return leftSum + rightSum;
    }
}
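Common Methods
| Method | Description |
|---|---|
| fork() | Schedules the task for asynchronous execution |
| join() | Waits for the task to finish and returns its result |
| invoke() | Runs the task and waits for its result (fork + join) |
| invokeAll() | Runs several tasks in parallel and returns once all have completed |
| compute() | Runs the task's logic directly in the calling thread |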
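To show several of these methods together, the following sketch splits an array into fixed-size chunks, runs them with the collection overload of `invokeAll()`, and then reads each chunk's result with `join()`. The class name `ChunkedMax`, the helper `max`, and the chunk size are assumptions made for this example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/** Finds the maximum of data[start, end); returns Integer.MIN_VALUE for an empty range. */
final class ChunkedMax extends RecursiveTask<Integer> {
    private static final int CHUNK = 1_000; // illustrative chunk size
    private final int[] data;
    private final int start;
    private final int end;

    ChunkedMax(int[] data, int start, int end) {
        this.data = data;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Integer compute() {
        if (end - start <= CHUNK) {
            int max = Integer.MIN_VALUE;
            for (int i = start; i < end; i++) {
                max = Math.max(max, data[i]);
            }
            return max;
        }
        List<ChunkedMax> chunks = new ArrayList<>();
        for (int from = start; from < end; ) {
            int to = from + Math.min(CHUNK, end - from);
            chunks.add(new ChunkedMax(data, from, to));
            from = to;
        }
        invokeAll(chunks); // returns once every chunk has completed
        int max = Integer.MIN_VALUE;
        for (ChunkedMax chunk : chunks) {
            max = Math.max(max, chunk.join()); // already complete, so join() just returns the result
        }
        return max;
    }

    /** invoke() on the pool submits the root task and blocks until its result is ready. */
    static int max(int[] data) {
        return ForkJoinPool.commonPool().invoke(new ChunkedMax(data, 0, data.length));
    }
}
```

Calling `ChunkedMax.max(new int[] {3, 9, 4})` stays below the chunk size and is computed directly; larger arrays go through the `invokeAll()` path.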
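Best practice: fork + compute + join
// Recommended: fork one half, compute the other half in the current thread, then join
left.fork();
rightResult = right.compute();
leftResult = left.join();

// Less efficient: forking both halves queues an extra task, and the current thread
// has no work of its own until it reaches join()
left.fork();
right.fork();
leftResult = left.join();
rightResult = right.join();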
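One way to make the difference between the two patterns concrete is to count how many tasks go through `fork()`. The sketch below sums the integers in a range and takes a flag that switches between the patterns; `ForkCountDemo`, its `run` helper, and the shared `AtomicInteger` counter are all illustrative and exist only for instrumentation.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

/** Sums the integers in [start, end) and counts how many fork() calls were made. */
final class ForkCountDemo extends RecursiveTask<Long> {
    private static final int THRESHOLD = 128;
    private final int start;
    private final int end;
    private final boolean forkBoth;    // true = fork both halves, false = fork + compute + join
    private final AtomicInteger forks; // instrumentation only

    ForkCountDemo(int start, int end, boolean forkBoth, AtomicInteger forks) {
        this.start = start;
        this.end = end;
        this.forkBoth = forkBoth;
        this.forks = forks;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += i;
            }
            return sum;
        }
        int mid = start + (end - start) / 2;
        ForkCountDemo left = new ForkCountDemo(start, mid, forkBoth, forks);
        ForkCountDemo right = new ForkCountDemo(mid, end, forkBoth, forks);
        left.fork();
        forks.incrementAndGet();
        if (forkBoth) {
            right.fork();
            forks.incrementAndGet();
            long leftSum = left.join();
            return leftSum + right.join();
        }
        long rightSum = right.compute(); // runs in the current thread, no queueing
        return left.join() + rightSum;
    }

    /** Returns {sum, number of fork() calls} for the range [0, n). */
    static long[] run(int n, boolean forkBoth) {
        AtomicInteger forks = new AtomicInteger();
        long sum = ForkJoinPool.commonPool().invoke(new ForkCountDemo(0, n, forkBoth, forks));
        return new long[] {sum, forks.get()};
    }
}
```

For `n = 1024` the range is split seven times. `run(1024, false)` reports 7 forks and `run(1024, true)` reports 14, while both return the same sum of 523776.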
Integration with Streams
How parallelStream is implemented
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);
List<Integer> squares = numbers.parallelStream()
        .map(n -> n * n)
        .collect(Collectors.toList());
Under the hood, parallel stream operations run on ForkJoinPool.commonPool() by default.
ForkJoinPool commonPool = ForkJoinPool.commonPool();
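A quick way to see what the common pool looks like on a given machine is to print its parallelism next to the processor count. By default the common pool's parallelism is one less than the number of available processors, and it can be overridden with the system property `java.util.concurrent.ForkJoinPool.common.parallelism`. The class and method names below are illustrative.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;

final class CommonPoolInfo {
    /** Sums the squares of the given numbers using a parallel stream. */
    static int sumOfSquares(List<Integer> numbers) {
        return numbers.parallelStream().mapToInt(n -> n * n).sum();
    }

    public static void main(String[] args) {
        // Both values depend on the machine and JVM settings.
        System.out.println("Available processors:    " + Runtime.getRuntime().availableProcessors());
        System.out.println("Common pool parallelism: " + ForkJoinPool.getCommonPoolParallelism());
        System.out.println(sumOfSquares(Arrays.asList(1, 2, 3, 4, 5))); // 55
    }
}
```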
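Using a custom thread pool for a parallel stream
ForkJoinPool customPool = new ForkJoinPool(4);
try {
    // The stream's tasks run in customPool because the terminal operation is invoked
    // from one of its worker threads. This relies on implementation behavior
    // rather than a documented guarantee.
    return customPool.submit(() ->
            numbers.parallelStream().map(...).collect(...)
    ).get();
} catch (Exception e) {
    e.printStackTrace();
} finally {
    customPool.shutdown();
}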
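For reference, here is a self-contained variant of the same idea. The squaring pipeline is only a placeholder for whatever the elided `map(...)` and `collect(...)` steps do, and `CustomPoolStream.squares` is a hypothetical helper name. Instead of printing the stack trace, this sketch rethrows failures so the method always either returns a result or throws.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;

final class CustomPoolStream {
    /** Squares the numbers with a parallel stream started from inside a dedicated pool. */
    static List<Integer> squares(List<Integer> numbers, int parallelism) {
        ForkJoinPool customPool = new ForkJoinPool(parallelism);
        try {
            return customPool.submit(() ->
                    numbers.parallelStream()
                            .map(n -> n * n)               // placeholder pipeline
                            .collect(Collectors.toList())
            ).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();            // preserve the interrupt status
            throw new IllegalStateException("Interrupted while waiting for the stream", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Stream pipeline failed", e.getCause());
        } finally {
            customPool.shutdown();                         // always release the worker threads
        }
    }

    public static void main(String[] args) {
        System.out.println(squares(Arrays.asList(1, 2, 3, 4, 5), 4)); // [1, 4, 9, 16, 25]
    }
}
```

Creating and shutting down a pool per call keeps the example short; a long-lived application would more likely reuse one pool.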
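Caveats
1. Task granularity
// Too fine-grained: the cost of creating and scheduling tasks outweighs the work they do
if (end - start <= 1) { ... }

// Reasonable: each leaf task does enough work to amortize that cost
if (end - start <= 10000) { ... }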
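The effect of the threshold is easy to quantify without running any threads: count how many task objects a halving strategy creates. The helper below is a plain recursive calculation written for this article (the name `TaskGranularity.countTasks` is made up), and it assumes the same `size <= threshold` stopping rule as the snippets above.

```java
final class TaskGranularity {
    /**
     * Number of tasks (leaf tasks plus the tasks that split) created when a range of
     * the given size is halved until each piece has at most `threshold` elements.
     */
    static long countTasks(int size, int threshold) {
        if (threshold < 1) {
            throw new IllegalArgumentException("threshold must be at least 1");
        }
        if (size <= threshold) {
            return 1; // a leaf task that does the work directly
        }
        int half = size / 2;
        return 1 + countTasks(half, threshold) + countTasks(size - half, threshold);
    }

    public static void main(String[] args) {
        System.out.println(countTasks(1_000_000, 1));      // 1999999
        System.out.println(countTasks(1_000_000, 10_000)); // 255
    }
}
```

For one million elements, a threshold of 1 creates almost two million task objects, while a threshold of 10,000 creates 255.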
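2. Avoid blocking operations
// Bad: a blocking call ties up a worker thread that could be running other tasks
public class BadTask extends RecursiveTask<String> {
    @Override
    protected String compute() {
        return callExternalAPI();
    }
}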
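When a blocking call inside a task truly cannot be avoided, the JDK offers `ForkJoinPool.ManagedBlocker`, which lets the pool know a worker is about to block so it can compensate, for example by activating a spare thread. This is a different technique from the one in the snippet above, shown here only as a minimal sketch: `callExternalApi` is a hypothetical stand-in that sleeps for 50 ms, and the class names are made up.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

final class ManagedBlockingSketch extends RecursiveTask<String> {

    /** Wraps one blocking call so the pool can compensate while it is in progress. */
    private static final class ApiCall implements ForkJoinPool.ManagedBlocker {
        private volatile String response; // null until the call has finished

        @Override
        public boolean block() throws InterruptedException {
            if (response == null) {
                response = callExternalApi();
            }
            return true; // no further blocking is necessary
        }

        @Override
        public boolean isReleasable() {
            return response != null;
        }
    }

    /** Hypothetical stand-in for a slow remote call. */
    static String callExternalApi() throws InterruptedException {
        Thread.sleep(50);
        return "response";
    }

    @Override
    protected String compute() {
        ApiCall call = new ApiCall();
        try {
            ForkJoinPool.managedBlock(call);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted during the blocking call", e);
        }
        return call.response;
    }

    static String fetch() {
        return ForkJoinPool.commonPool().invoke(new ManagedBlockingSketch());
    }
}
```

For workloads dominated by I/O, a separate executor sized for blocking work is often the simpler choice.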
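3. Avoid shared mutable state
// Bad: tasks running in parallel update a shared field without synchronization (data race)
private static int sum = 0; // shared by every task instance

protected Integer compute() {
    int localSum = calculate();
    sum += localSum; // not atomic: concurrent updates can be lost
    return sum;
}

// Good: each task works on local variables and returns its own result
protected Integer compute() {
    int localSum = calculate();
    return localSum;
}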
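Returning partial results, as in the good version above, is usually the cleanest option. If a task really has to report into a shared counter, a thread-safe accumulator such as `java.util.concurrent.atomic.LongAdder` avoids the lost updates. The following sketch counts even numbers that way; `EvenCounter` and `countEvens` are illustrative names.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.atomic.LongAdder;

/** Counts the even values in data[start, end) into a shared, thread-safe LongAdder. */
final class EvenCounter extends RecursiveAction {
    private static final int THRESHOLD = 1_000; // illustrative value
    private final int[] data;
    private final int start;
    private final int end;
    private final LongAdder evens;

    EvenCounter(int[] data, int start, int end, LongAdder evens) {
        this.data = data;
        this.start = start;
        this.end = end;
        this.evens = evens;
    }

    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            int local = 0; // accumulate locally first
            for (int i = start; i < end; i++) {
                if (data[i] % 2 == 0) {
                    local++;
                }
            }
            evens.add(local); // one thread-safe update per leaf task
            return;
        }
        int mid = start + (end - start) / 2;
        invokeAll(new EvenCounter(data, start, mid, evens),
                  new EvenCounter(data, mid, end, evens));
    }

    static long countEvens(int[] data) {
        LongAdder evens = new LongAdder();
        ForkJoinPool.commonPool().invoke(new EvenCounter(data, 0, data.length, evens));
        return evens.sum(); // safe to read: invoke() returns only after all subtasks finished
    }
}
```

Each leaf still does its counting in a local variable and touches the shared adder only once, which keeps contention low.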
Performance Comparison
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;

public class PerformanceTest {
    public static void main(String[] args) {
        long[] array = new long[100_000_000];
        Arrays.fill(array, 1);

        // Sequential stream
        long start1 = System.currentTimeMillis();
        long sum1 = Arrays.stream(array).sum();
        System.out.println("Single-threaded: " + (System.currentTimeMillis() - start1) + "ms");

        // ForkJoinPool with the SumTask defined earlier
        long start2 = System.currentTimeMillis();
        ForkJoinPool pool = new ForkJoinPool();
        long sum2 = pool.invoke(new SumTask(array, 0, array.length));
        System.out.println("ForkJoin: " + (System.currentTimeMillis() - start2) + "ms");

        // Parallel stream
        long start3 = System.currentTimeMillis();
        long sum3 = Arrays.stream(array).parallel().sum();
        System.out.println("parallelStream: " + (System.currentTimeMillis() - start3) + "ms");
    }
}
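Results (8-core CPU):
- Single-threaded: ~150 ms
- ForkJoin: ~30 ms (about 5x faster)
- parallelStream: ~35 ms
These are single-run wall-clock timings without JIT warm-up, so treat them as rough indications that will vary with hardware and JVM version.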
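Numbers like these are sensitive to JIT compilation and to whatever else the machine is doing. A slightly more careful measurement runs the workload a few times untimed before measuring and keeps the best of several timed runs. The helper below is a minimal sketch of that idea under those assumptions (`TimingSketch` and `bestNanos` are made-up names); for rigorous results a benchmark harness such as JMH is the usual choice.

```java
import java.util.Arrays;
import java.util.function.LongSupplier;

final class TimingSketch {
    // Results are added to this field so the JIT cannot discard the measured work.
    static volatile long sink;

    /** Runs `work` untimed `warmups` times, then returns the fastest of `runs` timed runs in nanoseconds. */
    static long bestNanos(LongSupplier work, int warmups, int runs) {
        for (int i = 0; i < warmups; i++) {
            sink += work.getAsLong();
        }
        long best = Long.MAX_VALUE;
        for (int i = 0; i < runs; i++) {
            long t0 = System.nanoTime();
            sink += work.getAsLong();
            best = Math.min(best, System.nanoTime() - t0);
        }
        return best;
    }

    public static void main(String[] args) {
        long[] array = new long[10_000_000];
        Arrays.fill(array, 1);
        long seq = bestNanos(() -> Arrays.stream(array).sum(), 5, 5);
        long par = bestNanos(() -> Arrays.stream(array).parallel().sum(), 5, 5);
        // Actual values depend on the machine.
        System.out.printf("sequential: %.1f ms, parallel: %.1f ms%n", seq / 1e6, par / 1e6);
    }
}
```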
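When to Use It
| Scenario | Suitable? |
|---|---|
| Summing or sorting large arrays | Yes |
| Recursive tree traversal | Yes |
| MapReduce-style computation | Yes |
| Blocking I/O | No |
| Tasks with complex dependencies | No |
| Small data sets | No |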
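Summary
With its work-stealing algorithm, ForkJoinPool delivers excellent parallel performance for CPU-bound tasks:
- Split: large tasks are broken down into smaller ones.
- Execute: each thread takes tasks from the head of its own deque.
- Steal: idle threads steal tasks from the tail of other threads' deques.
- Merge: the results of the subtasks are combined.
Used correctly, ForkJoinPool lets you take full advantage of the computing power of multi-core CPUs.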