ForkJoinPool工作窃取算法

ForkJoinPool工作窃取算法

ForkJoinPool 是 Java 7 引入的专为分治任务设计的线程池,采用工作窃取(Work-Stealing)算法。本文深入解析其原理和应用。

核心设计

工作窃取算法

线程A(繁忙)          线程B(空闲)
Deque: [T1, T2, T3] Deque: []
↓ ↓
T1执行中 从线程A尾部窃取T3

特点

  • 每个线程维护一个双端队列(Deque)
  • 自己从队列头部取任务执行
  • 空闲线程从其他线程队列尾部窃取任务
  • 减少线程间的竞争(头部 vs 尾部)

ForkJoinTask

RecursiveAction(无返回值)

public class PrintTask extends RecursiveAction {
private static final int THRESHOLD = 10;
private int start;
private int end;

public PrintTask(int start, int end) {
this.start = start;
this.end = end;
}

@Override
protected void compute() {
if (end - start <= THRESHOLD) {
// 直接执行
for (int i = start; i < end; i++) {
System.out.println(Thread.currentThread().getName() + ": " + i);
}
} else {
// 拆分任务
int mid = (start + end) / 2;
PrintTask left = new PrintTask(start, mid);
PrintTask right = new PrintTask(mid, end);

invokeAll(left, right); // 并行执行子任务
}
}
}

// 使用
ForkJoinPool pool = new ForkJoinPool();
pool.submit(new PrintTask(0, 100));
pool.shutdown();

RecursiveTask(有返回值)

public class SumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 10000;
private long[] array;
private int start;
private int end;

public SumTask(long[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}

@Override
protected Long compute() {
if (end - start <= THRESHOLD) {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}

int mid = (start + end) / 2;
SumTask left = new SumTask(array, start, mid);
SumTask right = new SumTask(array, mid, end);

left.fork(); // 异步执行左任务
long rightSum = right.compute(); // 同步执行右任务
long leftSum = left.join(); // 等待左任务结果

return leftSum + rightSum;
}
}

常用方法

方法 说明
fork() 异步执行任务
join() 等待任务完成并获取结果
invoke() 执行任务并等待结果(fork+join)
invokeAll() 并行执行多个任务
compute() 直接执行任务

最佳实践:fork + compute + join

// 好的模式
left.fork();
rightResult = right.compute();
leftResult = left.join();

// 避免两个fork
left.fork();
right.fork(); // 多了任务入队开销
leftResult = left.join();
rightResult = right.join();

与 Stream 的集成

parallelStream 的实现

List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5);

numbers.parallelStream()
.map(n -> n * n)
.collect(Collectors.toList());

底层:使用 ForkJoinPool.commonPool() 执行并行操作。

// 获取 common pool
ForkJoinPool commonPool = ForkJoinPool.commonPool();
// 默认线程数 = CPU核心数 - 1

自定义并行流线程池

// Java 8 中parallelStream无法指定线程池
// 需要使用CompletableFuture
ForkJoinPool customPool = new ForkJoinPool(4);
try {
return customPool.submit(() ->
numbers.parallelStream().map(...).collect(...)
).get();
} catch (Exception e) {
e.printStackTrace();
} finally {
customPool.shutdown();
}

注意事项

1. 任务粒度

// 错误:任务太小,拆分开销 > 执行开销
if (end - start <= 1) { ... } // 粒度太细

// 正确:根据实际测试调整阈值
if (end - start <= 10000) { ... } // 合理的阈值

2. 避免阻塞操作

// 错误:在ForkJoinTask中阻塞
public class BadTask extends RecursiveTask<String> {
protected String compute() {
return callExternalAPI(); // 阻塞HTTP调用!
}
}

// 正确:ForkJoinPool适合计算密集型任务

3. 避免共享可变状态

// 错误:并发修改共享变量
private int sum = 0;

protected Integer compute() {
sum += localSum; // 非线程安全
return sum;
}

// 正确:每个任务返回局部结果,最后合并
protected Integer compute() {
int localSum = calculate();
return localSum;
}

性能对比

public class PerformanceTest {
public static void main(String[] args) {
long[] array = new long[100_000_000];
Arrays.fill(array, 1);

// 单线程
long start1 = System.currentTimeMillis();
long sum1 = Arrays.stream(array).sum();
System.out.println("单线程: " + (System.currentTimeMillis() - start1) + "ms");

// ForkJoin
long start2 = System.currentTimeMillis();
ForkJoinPool pool = new ForkJoinPool();
long sum2 = pool.invoke(new SumTask(array, 0, array.length));
System.out.println("ForkJoin: " + (System.currentTimeMillis() - start2) + "ms");

// parallelStream
long start3 = System.currentTimeMillis();
long sum3 = Arrays.stream(array).parallel().sum();
System.out.println("parallelStream: " + (System.currentTimeMillis() - start3) + "ms");
}
}

结果(8核 CPU):

  • 单线程:~150ms
  • ForkJoin:~30ms(5倍提升)
  • parallelStream:~35ms

适用场景

场景 是否适合
大规模数组求和/排序 适合
递归树遍历 适合
MapReduce 计算 适合
阻塞 IO 操作 不适合
任务依赖复杂 不适合
数据量小 不适合

总结

ForkJoinPool 通过工作窃取算法,在计算密集型任务中实现了优异的并行性能:

  1. 分解:大任务拆分为小任务
  2. 执行:线程从自己的队列头部取任务
  3. 窃取:空闲线程从其他队列尾部偷任务
  4. 合并:汇总子任务结果

正确使用 ForkJoinPool,可以充分利用多核 CPU 的计算能力。


   转载规则


《ForkJoinPool工作窃取算法》 小乐 采用 知识共享署名 4.0 国际许可协议 进行许可。
  目录