很早就知道 JDK 7 的 Fork/Join 框架了,但一直没有学习。最近看了 Oracle 官网上 WordCount 的
例子
(http://www.oracle.com/technetwork/articles/java/fork-join-422606.html),
于是自己写了一个归并排序练练手。(代码中增加的打印语句完全是为了方便
理解。)
import java.util.*;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
/**
 * Fork/Join merge sort over an {@code int[]}.
 *
 * <p>Each task splits its array in two halves, sorts the halves as subtasks and
 * merges the two sorted results into a new array. Arrays of at most
 * {@link #SEQUENTIAL_THRESHOLD} elements are sorted directly.
 *
 * <p>Every task carries a {@code serial} label ("0", "0-0", "0-1", "0-0-1", ...)
 * describing its position in the task tree. The label is only used by the
 * tracing output printed to {@code System.out}. Because the two halves may run
 * on different worker threads, the trace lines of sibling subtrees may interleave.
 *
 * <p>The array passed to the constructor is never modified; {@link #compute()}
 * always returns a separate array.
 */
public class MergeSortTask extends RecursiveTask<int[]> {
    private static final String ROOT = "0";
    private static final String LEFT = "0";
    private static final String RIGHT = "1";

    /** Arrays with at most this many elements are sorted without further splitting. */
    private static final int SEQUENTIAL_THRESHOLD = 3;

    private final int[] array;
    private final String serial;

    /**
     * @param array  the values to sort; not modified by this task
     * @param serial label identifying this task in the trace output
     */
    public MergeSortTask(int[] array, String serial) {
        this.array = array;
        this.serial = serial;
    }

    /**
     * Sorts this task's array.
     *
     * @return a new array holding the same values in ascending order
     */
    @Override
    protected int[] compute() {
        int length = array.length;
        if (length <= SEQUENTIAL_THRESHOLD) {
            // Sort a copy so that small inputs are left untouched, just like large ones.
            int[] sorted = array.clone();
            Arrays.sort(sorted);
            return sorted;
        }

        int middle = length / 2;
        String leftSerial = getLeftSerial();
        String rightSerial = getRightSerial();
        int[] leftArray = Arrays.copyOfRange(array, 0, middle);
        int[] rightArray = Arrays.copyOfRange(array, middle, length);
        MergeSortTask leftTask = new MergeSortTask(leftArray, leftSerial);
        MergeSortTask rightTask = new MergeSortTask(rightArray, rightSerial);

        System.out.println(new MidArray(leftSerial, leftArray));
        System.out.println(new MidArray(rightSerial, rightArray));

        // Schedule the left half asynchronously, sort the right half in the current
        // worker, and only then wait for the left half. Calling fork().join() on each
        // task in turn would finish the left half before the right one is even
        // submitted, i.e. the two halves would never run in parallel.
        leftTask.fork();
        int[] rightResult = rightTask.compute();
        int[] leftResult = leftTask.join();

        System.out.println(new MidResult(new MidArray(leftSerial, leftResult)));
        System.out.println(new MidResult(new MidArray(rightSerial, rightResult)));

        return merge(leftResult, rightResult);
    }

    /** Returns the label of the left child task. */
    private String getLeftSerial() {
        return serial + "-" + LEFT;
    }

    /** Returns the label of the right child task. */
    private String getRightSerial() {
        return serial + "-" + RIGHT;
    }

    /**
     * Merges two ascending arrays into one ascending array.
     *
     * <p>On equal values the element from {@code leftArray} is taken first.
     *
     * @param leftArray  ascending values
     * @param rightArray ascending values
     * @return a new array containing all values of both inputs in ascending order
     */
    private static int[] merge(int[] leftArray, int[] rightArray) {
        int leftArrayLength = leftArray.length;
        int rightArrayLength = rightArray.length;
        int[] resultArray = new int[leftArrayLength + rightArrayLength];
        int resultIndex = 0;
        int leftIndex = 0;
        int rightIndex = 0;

        while (leftIndex < leftArrayLength && rightIndex < rightArrayLength) {
            if (leftArray[leftIndex] <= rightArray[rightIndex]) {
                resultArray[resultIndex++] = leftArray[leftIndex++];
            } else {
                resultArray[resultIndex++] = rightArray[rightIndex++];
            }
        }

        // At most one of the inputs still has elements left; append that tail.
        if (leftIndex < leftArrayLength) {
            System.arraycopy(leftArray, leftIndex, resultArray,
                    resultIndex, leftArrayLength - leftIndex);
        }
        if (rightIndex < rightArrayLength) {
            System.arraycopy(rightArray, rightIndex, resultArray,
                    resultIndex, rightArrayLength - rightIndex);
        }
        return resultArray;
    }

    /** Builds an array of {@code size} random values in the range [0, 100). */
    private static int[] buildTestArray(int size) {
        int[] testArray = new int[size];
        Random rand = new Random();
        for (int i = 0; i < size; i++) {
            testArray[i] = rand.nextInt(100);
        }
        return testArray;
    }

    /** Demo: sorts ten random numbers and prints the input, the trace and the result. */
    public static void main(String[] args) {
        int[] testArray = buildTestArray(10);
        System.out.println(Arrays.toString(testArray));
        MergeSortTask mergeSorter = new MergeSortTask(testArray, ROOT);
        ForkJoinPool pool = new ForkJoinPool();
        System.out.println(Arrays.toString(pool.invoke(mergeSorter)));
    }
}
//辅助类
import java.util.Arrays;
/**
 * Printable snapshot of an array together with the label of the task that owns it.
 *
 * <p>The text form is indented by one {@code INDENT} per segment that
 * {@code serial.split("-")} yields, followed by
 * {@code "<serial>, is: <array contents>"}.
 */
public class MidArray {
    private static final String INDENT = " ";

    private final String serial;
    private final int[] result;

    /**
     * @param serial dash-separated task label, e.g. {@code "0-1-0"}
     * @param result the array to display; stored by reference, not copied
     */
    public MidArray(String serial, int[] result) {
        this.serial = serial;
        this.result = result;
    }

    /** Renders the indented label followed by the array contents. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        // The label never changes, so its depth is computed once up front.
        int remaining = serial.split("-").length;
        while (remaining > 0) {
            text.append(INDENT);
            remaining--;
        }
        text.append(serial);
        text.append(", is: ");
        text.append(Arrays.toString(result));
        return text.toString();
    }
}
/**
 * Wraps a {@link MidArray} and marks it as a finished (sorted) intermediate result
 * when printed.
 */
public class MidResult {
    private final MidArray midArray;

    /**
     * @param midArray the labelled array to present as a result
     */
    public MidResult(MidArray midArray) {
        this.midArray = midArray;
    }

    /** Returns the wrapped array's text followed by the {@code " [Result]"} marker. */
    @Override
    public String toString() {
        String base = midArray.toString();
        return new StringBuilder(base).append(" [Result]").toString();
    }
}
一次运行的输出结果:
[71, 8, 25, 9, 4, 41, 38, 87, 45, 96]
0-0, is: [71, 8, 25, 9, 4]
0-1, is: [41, 38, 87, 45, 96]
0-0-0, is: [71, 8]
0-0-1, is: [25, 9, 4]
0-0-0, is: [8, 71] [Result]
0-0-1, is: [4, 9, 25] [Result]
0-1-0, is: [41, 38]
0-1-1, is: [87, 45, 96]
0-1-0, is: [38, 41] [Result]
0-1-1, is: [45, 87, 96] [Result]
0-0, is: [4, 8, 9, 25, 71] [Result]
0-1, is: [38, 41, 45, 87, 96] [Result]
[4, 8, 9, 25, 38, 41, 45, 71, 87, 96]