class="java" name="code">
/**
 * A half-open slice {@code [start, end)} of an int array whose maximum element is wanted.
 * The backing array is shared (not copied) between a problem and its subproblems.
 */
public static class SelectMaxProblem {
    /** Number of elements in the slice, i.e. {@code end - start}. */
    public final int size;
    private final int[] numbers;
    private final int start;
    private final int end;

    /**
     * @param numbers backing array; referenced, not copied
     * @param start   first index of the slice, inclusive
     * @param end     end index of the slice, exclusive
     */
    public SelectMaxProblem(int[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
        size = end - start;
    }

    /**
     * Finds the maximum by a single linear scan of the slice.
     *
     * @return the largest element, or {@link Integer#MIN_VALUE} when the slice is empty
     */
    public int solveSequentially() {
        int best = Integer.MIN_VALUE;
        int idx = start;
        while (idx < end) {
            best = Math.max(best, numbers[idx]);
            idx++;
        }
        return best;
    }

    /**
     * Creates a problem over a sub-slice of this one.
     *
     * @param subStart start offset relative to this problem's start, inclusive
     * @param subEnd   end offset relative to this problem's start, exclusive
     * @return a new problem sharing the same backing array
     */
    public SelectMaxProblem subproblem(int subStart, int subEnd) {
        final int absoluteStart = start + subStart;
        final int absoluteEnd = start + subEnd;
        return new SelectMaxProblem(numbers, absoluteStart, absoluteEnd);
    }

    /** Renders the absolute bounds, e.g. {@code P{start: 0, end:40}}. */
    @Override
    public String toString() {
        return String.format("P{start:%2s, end:%2s}", start, end);
    }
}
public static class MaxWithFJ extends RecursiveAction {
private final int threshold;
private final SelectMaxProblem problem;
public int result;
public MaxWithFJ(SelectMaxProblem problem, int threshold) {
this.problem = problem;
this.threshold = threshold;
}
protected void compute() {
String pre = problem.toString();
if (problem.size < threshold) {
result = problem.solveSequentially();
print(pre,"return result:"+result);
}else {
int midpoint = problem.size / 2;
SelectMaxProblem leftProblem = problem.subproblem(0, midpoint);
SelectMaxProblem rightProblem = problem.subproblem(midpoint + 1, problem.size);
MaxWithFJ left = new MaxWithFJ(leftProblem, threshold);
MaxWithFJ right = new MaxWithFJ(rightProblem, threshold);
print(pre,leftProblem + "|" + rightProblem);
left.fork();
right.fork();
print(pre,"fork");
print(pre,"begin left join");
left.join();
print(pre,"after left join.begin rigth join");
right.join();
print(pre,"join");
result = Math.max(left.result, right.result);
}
}
/**
 * Demo entry point: finds the maximum of the values 0..39 using a {@link MaxWithFJ} task
 * on a two-thread ForkJoinPool, then prints the result.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    int size = 40;
    int[] numbers = new int[size];
    for (int i = 0; i < size; i++) {
        numbers[i] = i;
    }
    SelectMaxProblem problem = new SelectMaxProblem(numbers, 0, numbers.length);
    int threshold = 10;
    int nThreads = 2;
    MaxWithFJ mfj = new MaxWithFJ(problem, threshold);
    ForkJoinPool fjPool = new ForkJoinPool(nThreads);
    try {
        fjPool.invoke(mfj);
    } finally {
        // Release the pool's worker threads once the computation has finished (or failed).
        fjPool.shutdown();
    }
    int result = mfj.result;
    print("main", result);
}
/**
 * Writes one log line to standard output in the form {@code Thread[id] pre > object}.
 *
 * @param pre    label identifying the caller
 * @param object the message, rendered through {@code %s}
 */
static void print(String pre, Object object) {
    long tid = Thread.currentThread().getId();
    System.out.println(String.format("Thread[%s] %s > %s", tid, pre, object));
}
}
输出结果:
Thread[11] P{start: 0, end:40} > P{start: 0, end:20}|P{start:21, end:40}
Thread[11] P{start: 0, end:40} > fork
Thread[11] P{start: 0, end:40} > begin left join
Thread[12] P{start: 0, end:20} > P{start: 0, end:10}|P{start:11, end:20}
Thread[12] P{start: 0, end:20} > fork
Thread[12] P{start: 0, end:20} > begin left join
Thread[13] P{start:21, end:40} > P{start:21, end:30}|P{start:31, end:40}
Thread[13] P{start:21, end:40} > fork
Thread[13] P{start:21, end:40} > begin left join
Thread[12] P{start: 0, end:10} > P{start: 0, end: 5}|P{start: 6, end:10}
Thread[13] P{start:21, end:30} > return result:29
Thread[12] P{start: 0, end:10} > fork
Thread[13] P{start:21, end:40} > after left join.begin rigth join
Thread[12] P{start: 0, end:10} > begin left join
Thread[12] P{start: 0, end: 5} > return result:4
Thread[13] P{start:31, end:40} > return result:39
Thread[12] P{start: 0, end:10} > after left join.begin rigth join
Thread[13] P{start:21, end:40} > join
Thread[12] P{start: 6, end:10} > return result:9
Thread[13] P{start:11, end:20} > return result:19
Thread[12] P{start: 0, end:10} > join
Thread[12] P{start: 0, end:20} > after left join.begin rigth join
Thread[12] P{start: 0, end:20} > join
Thread[11] P{start: 0, end:40} > after left join.begin rigth join
Thread[11] P{start: 0, end:40} > join
Thread[1] main > 39
说明:
1. RecursiveAction#fork 表示把任务异步地交给 ForkJoinPool 执行:在工作线程中调用时,任务被放入当前线程的工作队列,之后可能由当前线程自己执行,也可能被其他空闲线程"窃取"执行;它并不会为每个任务新建一个线程。
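2. RecursiveAction#join 会等待该任务完成。与普通的 Thread#join 不同,工作线程在调用 join 时不会单纯地阻塞,而是会尽量去执行任务:如果被等待的任务还在自己的队列中,就直接执行它;否则会去帮助执行其他待处理的任务。从输出可以看到,编号 12 和 13 的线程在开始 join 之后继续执行了子任务(例如线程 13 执行了由线程 12 fork 出的 P{start:11, end:20})。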