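Compared with the basic, traditional way of implementing concurrent systems, the Executor framework offers many advantages. The traditional approach is to write a class that implements the Runnable interface and then use objects of that class to create Thread instances directly.
This approach has problems, especially when you start too many threads: every thread has a creation and scheduling cost, and too many of them can degrade the performance of the whole system.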
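To make the contrast concrete, here is a minimal sketch that runs trivial tasks in two ways. The first method starts one new Thread per Runnable. The second submits the tasks to a fixed-size pool created with Executors.newFixedThreadPool(). The class and method names are hypothetical, not taken from the text. The pool version reuses at most poolSize threads, no matter how many tasks are submitted.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadVsExecutor {

    // Traditional approach: a new Thread is created and started for every Runnable.
    static int runWithThreads(int tasks) throws InterruptedException {
        AtomicInteger done = new AtomicInteger();
        Thread[] threads = new Thread[tasks];
        for (int i = 0; i < tasks; i++) {
            threads[i] = new Thread(() -> done.incrementAndGet());
            threads[i].start();
        }
        for (Thread t : threads) {
            t.join(); // wait for every thread to finish
        }
        return done.get();
    }

    // Executor approach: the same kind of Runnables are queued and run by a bounded
    // set of reused pool threads. Returns how many distinct threads actually ran them.
    static int runWithExecutor(int tasks, int poolSize) throws InterruptedException {
        Set<String> threadNames = ConcurrentHashMap.newKeySet();
        ExecutorService executor = Executors.newFixedThreadPool(poolSize);
        for (int i = 0; i < tasks; i++) {
            executor.execute(() -> threadNames.add(Thread.currentThread().getName()));
        }
        executor.shutdown();                            // accept no new tasks
        executor.awaitTermination(1, TimeUnit.MINUTES); // wait for queued tasks to finish
        return threadNames.size();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("Tasks run with one thread each: " + runWithThreads(100));
        System.out.println("Distinct pool threads used: " + runWithExecutor(100, 4));
    }
}
```

The first method creates 100 threads. The second never uses more than 4, because finished pool threads pick up the next queued task.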
This chapter provides several examples that explain how to use executors. The example code is a bit long, but it is still fairly simple and clear.
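The k-nearest neighbors (k-NN) algorithm is a simple machine-learning algorithm for supervised classification. Its main components are a training dataset of labeled samples and a distance metric; the examples below use Euclidean distance.
To classify a sample, the algorithm computes the distance between that sample and every sample in the training dataset. It then takes the k samples with the smallest distances. The label that occurs most often among those k samples is assigned to the sample being classified. Following the lesson from Chapter 1, we start with a serial version of the algorithm and then evolve it into a concurrent version.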
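The two steps just described can be written as formulas. The notation is ours, not the author's:

- x is the sample to classify.
- y_j is the j-th training sample and c_j is its label.
- n is the number of attributes returned by getExample().
- N_k(x) is the set of indices of the k training samples nearest to x.

```latex
d(\mathbf{x}, \mathbf{y}_j) = \sqrt{\sum_{i=1}^{n} \left(x_i - y_{j,i}\right)^2}

\hat{c}(\mathbf{x}) = \arg\max_{c} \; \bigl|\{\, j \in N_k(\mathbf{x}) : c_j = c \,\}\bigr|
```

For example, for x = (0, 0) and y_j = (3, 4), the distance is sqrt(9 + 16) = 5. With the HashMap-based vote used in the code, a tie between labels is resolved arbitrarily.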
```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class KnnClassifier {

    private List<? extends Sample> dataSet;
    private int k;

    public KnnClassifier(List<? extends Sample> dataSet, int k) {
        this.dataSet = dataSet;
        this.k = k;
    }

    public String classify(Sample example) {
        Distance[] distances = new Distance[dataSet.size()];
        int index = 0;
        // Compute the distance between the new sample and every sample in the training dataset
        for (Sample localExample : dataSet) {
            distances[index] = new Distance();
            distances[index].setIndex(index);
            distances[index].setDistance(EuclideanDistanceCalculator.calculate(localExample, example));
            index++;
        }
        // Sort the distances so that the k nearest samples come first
        Arrays.sort(distances);
        // Count how often each label appears among the k nearest samples
        Map<String, Integer> results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, (a, b) -> a + b);
        }
        // Return the label that occurs most often among the k nearest samples
        return Collections.max(results.entrySet(), Map.Entry.comparingByValue()).getKey();
    }
}
```
```java
// Computes the Euclidean distance between two samples
public class EuclideanDistanceCalculator {

    public static double calculate(Sample example1, Sample example2) {
        double ret = 0.0d;
        double[] data1 = example1.getExample();
        double[] data2 = example2.getExample();
        if (data1.length != data2.length) {
            throw new IllegalArgumentException("Vector doesn't have the same length");
        }
        // Sum of squared differences, attribute by attribute
        for (int i = 0; i < data1.length; i++) {
            ret += Math.pow(data1[i] - data2[i], 2);
        }
        return Math.sqrt(ret);
    }
}
```
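The classes above rely on two classes that the text does not show:

- Sample: a feature vector returned by getExample(), plus a label returned by getTag().
- Distance: an index/distance pair that Arrays.sort() must be able to order.

The following is a hypothetical minimal version of each, inferred only from how they are used. The author's actual classes may differ.

```java
import java.util.Arrays;

// Hypothetical minimal Sample: a feature vector plus its class label (tag).
class Sample {
    private final double[] example;
    private final String tag;

    Sample(double[] example, String tag) {
        this.example = example;
        this.tag = tag;
    }

    public double[] getExample() { return example; }
    public String getTag() { return tag; }
}

// Hypothetical minimal Distance: remembers which training sample it belongs to
// and orders itself by distance, which is what Arrays.sort() relies on.
class Distance implements Comparable<Distance> {
    private int index;
    private double distance;

    public int getIndex() { return index; }
    public void setIndex(int index) { this.index = index; }
    public double getDistance() { return distance; }
    public void setDistance(double distance) { this.distance = distance; }

    @Override
    public int compareTo(Distance other) {
        return Double.compare(this.distance, other.distance);
    }
}

public class KnnTypesDemo {
    public static void main(String[] args) {
        double[] values = {5.0, 1.0, 3.0};
        Distance[] distances = new Distance[values.length];
        for (int i = 0; i < values.length; i++) {
            distances[i] = new Distance();
            distances[i].setIndex(i);
            distances[i].setDistance(values[i]);
        }
        Arrays.sort(distances); // ascending by distance: indices 1, 2, 0
        for (Distance d : distances) {
            System.out.println(d.getIndex() + " -> " + d.getDistance());
        }
    }
}
```

With classes like these available, new KnnClassifier(trainingSamples, k).classify(sample) returns the majority label among the k nearest training samples.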
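If you analyze the serial algorithm above, you will find two places where concurrency can be applied. The first is the computation of the distances between the input sample and the samples of the training dataset. The second is the sorting of those distances; the concurrent classifiers below can optionally use Arrays.parallelSort() for this step.
In the fine-grained concurrent version, we create one task for each distance computation between the input sample and a sample of the training dataset. That is what "fine-grained" means here: we create a large number of small tasks.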
```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;

public class KnnClassifierParallelIndividual {

    private List<? extends Sample> dataSet;
    private int k;
    private ThreadPoolExecutor executor;
    private int numThreads;
    private boolean parallelSort;

    public KnnClassifierParallelIndividual(List<? extends Sample> dataSet, int k, int factor, boolean parallelSort) {
        this.dataSet = dataSet;
        this.k = k;
        // Size the pool from the number of processors (cores) available at run time, times a factor
        numThreads = factor * Runtime.getRuntime().availableProcessors();
        executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
        this.parallelSort = parallelSort;
    }

    /**
     * Because we create one task per distance computation, the main thread must wait
     * until all tasks have finished before it can continue. We synchronize on their
     * completion with a CountDownLatch initialized to the number of tasks, that is,
     * the number of samples in the dataset. Each task calls countDown() when it finishes.
     */
    public String classify(Sample example) throws Exception {
        Distance[] distances = new Distance[dataSet.size()];
        CountDownLatch endController = new CountDownLatch(dataSet.size());
        int index = 0;
        for (Sample localExample : dataSet) {
            IndividualDistanceTask task = new IndividualDistanceTask(distances, index, localExample, example, endController);
            executor.execute(task);
            index++;
        }
        endController.await(); // block until every distance has been computed
        if (parallelSort) {
            Arrays.parallelSort(distances);
        } else {
            Arrays.sort(distances);
        }
        Map<String, Integer> results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, (a, b) -> a + b);
        }
        // Return the label that occurs most often among the k nearest samples
        return Collections.max(results.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    // Shut down the executor; otherwise its non-daemon threads keep the JVM alive
    public void destroy() {
        executor.shutdown();
    }
}
```
```java
import java.util.concurrent.CountDownLatch;

// Computes the distance between the input sample and ONE sample of the training dataset
public class IndividualDistanceTask implements Runnable {

    private Distance[] distances;
    private int index;
    private Sample localExample;
    private Sample example;
    private CountDownLatch endController;

    public IndividualDistanceTask(Distance[] distances, int index, Sample localExample,
                                  Sample example, CountDownLatch endController) {
        this.distances = distances;
        this.index = index;
        this.localExample = localExample;
        this.example = example;
        this.endController = endController;
    }

    @Override
    public void run() {
        distances[index] = new Distance();
        distances[index].setIndex(index);
        distances[index].setDistance(EuclideanDistanceCalculator.calculate(localExample, example));
        // The task is finished: call countDown() on the CountDownLatch
        endController.countDown();
    }
}
```
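The two classes above use a common coordination pattern:

1. Create N tasks.
2. Create one CountDownLatch initialized to N.
3. Call countDown() at the end of each task.
4. Call await() in the caller.

The pattern can be tried in isolation. In this sketch, with hypothetical names, each task squares one array element. Each task writes only its own slot, and CountDownLatch guarantees that actions before countDown() happen-before the return from await(). The main thread therefore sees every result.

```java
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class LatchDemo {

    // Squares every element of input in parallel, one task per element,
    // and waits for all tasks with a CountDownLatch before returning.
    static double[] squareAll(double[] input, int poolSize) throws InterruptedException {
        double[] output = new double[input.length];
        CountDownLatch done = new CountDownLatch(input.length); // one count per task
        ExecutorService executor = Executors.newFixedThreadPool(poolSize);
        try {
            for (int i = 0; i < input.length; i++) {
                final int index = i;
                executor.execute(() -> {
                    output[index] = input[index] * input[index]; // each task writes its own slot
                    done.countDown();                            // signal completion
                });
            }
            done.await(); // returns once the count reaches zero
        } finally {
            executor.shutdown();
        }
        return output;
    }

    public static void main(String[] args) throws InterruptedException {
        double[] result = squareAll(new double[] {1, 2, 3, 4}, 2);
        System.out.println(Arrays.toString(result)); // [1.0, 4.0, 9.0, 16.0]
    }
}
```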
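The problem with the fine-grained version is that it creates too many tasks. In the coarse-grained version, each task processes a subset of the dataset, which avoids creating so many tasks.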
```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;

public class KnnClassifierParallelGroup {

    private List<? extends Sample> dataSet;
    private int k;
    private ThreadPoolExecutor executor;
    private int numThreads;
    private boolean parallelSort;

    public KnnClassifierParallelGroup(List<? extends Sample> dataSet, int k, int factor, boolean parallelSort) {
        this.dataSet = dataSet;
        this.k = k;
        // Size the pool from the number of processors (cores) available at run time, times a factor
        numThreads = factor * Runtime.getRuntime().availableProcessors();
        executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
        this.parallelSort = parallelSort;
    }

    /**
     * Each task now processes a contiguous block of the dataset, so there are only
     * numThreads tasks. The CountDownLatch is therefore initialized to numThreads,
     * and each task calls countDown() once, after finishing its whole block.
     */
    public String classify(Sample example) throws Exception {
        Distance[] distances = new Distance[dataSet.size()];
        CountDownLatch endController = new CountDownLatch(numThreads);
        // Split the dataset into numThreads blocks of 'length' samples;
        // the last block also takes the remainder of the division
        int length = dataSet.size() / numThreads;
        int startIndex = 0, endIndex = length;
        for (int i = 0; i < numThreads; i++) {
            GroupDistanceTask task = new GroupDistanceTask(distances, startIndex, endIndex, dataSet, example, endController);
            startIndex = endIndex;
            if (i < numThreads - 2) {
                endIndex = endIndex + length;
            } else {
                endIndex = dataSet.size(); // the next task is the last one
            }
            executor.execute(task);
        }
        endController.await(); // block until every block has been processed
        if (parallelSort) {
            Arrays.parallelSort(distances);
        } else {
            Arrays.sort(distances);
        }
        Map<String, Integer> results = new HashMap<>();
        for (int i = 0; i < k; i++) {
            Sample localExample = dataSet.get(distances[i].getIndex());
            String tag = localExample.getTag();
            results.merge(tag, 1, (a, b) -> a + b);
        }
        // Return the label that occurs most often among the k nearest samples
        return Collections.max(results.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    // Shut down the executor; otherwise its non-daemon threads keep the JVM alive
    public void destroy() {
        executor.shutdown();
    }
}
```
```java
import java.util.List;
import java.util.concurrent.CountDownLatch;

// Computes the distances for the samples in [startIndex, endIndex) of the training dataset
public class GroupDistanceTask implements Runnable {

    private Distance[] distances;
    private int startIndex, endIndex;
    private Sample example;
    private List<? extends Sample> dataSet;
    private CountDownLatch endController;

    public GroupDistanceTask(Distance[] distances, int startIndex, int endIndex,
                             List<? extends Sample> dataSet, Sample example, CountDownLatch endController) {
        this.distances = distances;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
        this.example = example;
        this.dataSet = dataSet;
        this.endController = endController;
    }

    @Override
    public void run() {
        for (int index = startIndex; index < endIndex; index++) {
            Sample localExample = dataSet.get(index);
            distances[index] = new Distance();
            distances[index].setIndex(index);
            distances[index].setDistance(EuclideanDistanceCalculator.calculate(localExample, example));
        }
        // The whole block is done: count down once
        endController.countDown();
    }
}
```
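The block boundaries computed in classify() above can be checked in isolation. This sketch, with a hypothetical class name, reproduces the same start/end index loop. Instead of creating a GroupDistanceTask, it records each [startIndex, endIndex) pair.

```java
import java.util.ArrayList;
import java.util.List;

public class BlockSplitDemo {

    // Same index logic as KnnClassifierParallelGroup.classify():
    // numThreads blocks of size / numThreads elements, the last block taking the remainder.
    static List<int[]> split(int size, int numThreads) {
        List<int[]> blocks = new ArrayList<>();
        int length = size / numThreads;
        int startIndex = 0, endIndex = length;
        for (int i = 0; i < numThreads; i++) {
            blocks.add(new int[] {startIndex, endIndex});
            startIndex = endIndex;
            if (i < numThreads - 2) {
                endIndex = endIndex + length;
            } else {
                endIndex = size;
            }
        }
        return blocks;
    }

    public static void main(String[] args) {
        for (int[] block : split(10, 4)) {
            System.out.println("[" + block[0] + ", " + block[1] + ")");
        }
    }
}
```

For 10 samples and 4 threads this prints [0, 2), [2, 4), [4, 6) and [6, 10). The last task takes the two leftover samples, and the blocks cover the dataset without gaps or overlaps.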
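The next example is a client/server application in which a server receives requests over a socket and executes a command for each one.
As with the previous example, we start with the serial version and then move on to the concurrent version.
The program consists of three main parts: the data access object (the WDIDAO class, whose code is not shown here), the command classes, and the server.
The following is the code of the command classes in the serial version: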
```java
//***************** Serial version: command classes *****************//

// Abstract base class for all commands
public abstract class Command {

    protected String[] command;

    public Command(String[] command) {
        this.command = command;
    }

    public abstract String execute();
}

// Command for Query requests
public class QueryCommand extends Command {

    public QueryCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        WDIDAO dao = WDIDAO.getDAO();
        if (command.length == 3) {
            return dao.query(command[1], command[2]);
        } else if (command.length == 4) {
            try {
                return dao.query(command[1], command[2], Short.parseShort(command[3]));
            } catch (Exception e) {
                return "ERROR;Bad Command";
            }
        } else {
            return "ERROR;Bad Command";
        }
    }
}

// Command for Report requests
public class ReportCommand extends Command {

    public ReportCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        WDIDAO dao = WDIDAO.getDAO();
        return dao.report(command[1]);
    }
}

// Command for Stop requests
public class StopCommand extends Command {

    public StopCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        return "Server stopped";
    }
}

// Handles requests the server does not support
public class ErrorCommand extends Command {

    public ErrorCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        return "Unknown command: " + command[0];
    }
}
```
The following is the code of the server:
```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.ServerSocket;
import java.net.Socket;

public class SerialServer {

    public static void main(String[] args) throws IOException {
        boolean stopServer = false;
        System.out.println("Initialization completed.");
        try (ServerSocket serverSocket = new ServerSocket(Constants.SERIAL_PORT)) {
            // Keep looping until a Stop command sets stopServer to true.
            // Only one client is handled at a time: accept() is not called again
            // until the current request has been processed completely.
            do {
                try (Socket clientSocket = serverSocket.accept();
                     PrintWriter out = new PrintWriter(clientSocket.getOutputStream(), true);
                     BufferedReader in = new BufferedReader(new InputStreamReader(clientSocket.getInputStream()))) {
                    String line = in.readLine();
                    Command command;
                    String[] commandData = line.split(";");
                    System.out.println("Command: " + commandData[0]);
                    switch (commandData[0]) {
                    case "q":
                        System.out.println("Query");
                        command = new QueryCommand(commandData);
                        break;
                    case "r":
                        System.out.println("Report");
                        command = new ReportCommand(commandData);
                        break;
                    case "z":
                        System.out.println("Stop");
                        command = new StopCommand(commandData);
                        stopServer = true;
                        break;
                    default:
                        System.out.println("Error");
                        command = new ErrorCommand(commandData);
                    }
                    String response = command.execute();
                    System.out.println(response);
                    out.println(response); // send the response back to the client
                } catch (IOException e) {
                    e.printStackTrace();
                }
            } while (!stopServer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
```
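The text does not show a client. The sketch below is a minimal one that assumes the line-based protocol implied by the server code:

- The client sends one request line whose fields are separated by ";".
- The server sends one response line back.

The class name CommandClient is hypothetical. A reply is only received if the server writes the response to the socket, as RequestTask does in the concurrent version.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.Socket;

public class CommandClient {

    // Opens a connection, sends one request line (for example "z" for Stop),
    // and returns the single response line, or null if the server sent nothing.
    public static String send(String host, int port, String request) throws IOException {
        try (Socket socket = new Socket(host, port);
             PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
             BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
            out.println(request); // autoflush is on, so the line is sent immediately
            return in.readLine();
        }
    }

    public static void main(String[] args) throws IOException {
        // Usage: java CommandClient <host> <port> <request>
        System.out.println(send(args[0], Integer.parseInt(args[1]), args[2]));
    }
}
```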
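The serial version above has a serious performance problem: the server can process only one request at a time, and every other request has to wait. In the concurrent version, the main thread only accepts connections. For each request it creates a task and hands it to a thread pool, whose threads execute it.
The following is the code of the command classes in the concurrent version. Most of it is the same as in the serial version. The differences are the Stop command, which now shuts the server down, and a new Status command that reports the state of the thread pool. The class names now start with "Concurrent".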
```java
import java.util.concurrent.ThreadPoolExecutor;

//***************** Concurrent version: command classes *****************//

// Abstract base class for all commands
public abstract class Command {
    protected String[] command;

    public Command(String[] command) {
        this.command = command;
    }

    public abstract String execute();
}

// Command for Query requests
public class ConcurrentQueryCommand extends Command {
    public ConcurrentQueryCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        WDIDAO dao = WDIDAO.getDAO();
        if (command.length == 3) {
            return dao.query(command[1], command[2]);
        } else if (command.length == 4) {
            try {
                return dao.query(command[1], command[2], Short.parseShort(command[3]));
            } catch (Exception e) {
                return "ERROR;Bad Command";
            }
        } else {
            return "ERROR;Bad Command";
        }
    }
}

// Command for Report requests
public class ConcurrentReportCommand extends Command {
    public ConcurrentReportCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        WDIDAO dao = WDIDAO.getDAO();
        return dao.report(command[1]);
    }
}

// Command for Stop requests: unlike the serial version, it shuts the server down itself
public class ConcurrentStopCommand extends Command {
    public ConcurrentStopCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        ConcurrentServer.shutdown();
        return "Server stopped";
    }
}

// Handles requests the server does not support
public class ConcurrentErrorCommand extends Command {
    public ConcurrentErrorCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        return "Unknown command: " + command[0];
    }
}

// New in the concurrent version: reports the status of the server's thread pool
public class ConcurrentStatusCommand extends Command {
    public ConcurrentStatusCommand(String[] command) {
        super(command);
    }

    @Override
    public String execute() {
        ThreadPoolExecutor executor = ConcurrentServer.getExecutor();
        StringBuilder sb = new StringBuilder();
        sb.append("Server Status;");
        sb.append("Active Threads: ").append(executor.getActiveCount()).append(";");
        sb.append("Maximum Pool Size: ").append(executor.getMaximumPoolSize()).append(";");
        sb.append("Core Pool Size: ").append(executor.getCorePoolSize()).append(";");
        sb.append("Pool Size: ").append(executor.getPoolSize()).append(";");
        sb.append("Largest Pool Size: ").append(executor.getLargestPoolSize()).append(";");
        sb.append("Completed Task Count: ").append(executor.getCompletedTaskCount()).append(";");
        sb.append("Task Count: ").append(executor.getTaskCount()).append(";");
        sb.append("Queue Size: ").append(executor.getQueue().size()).append(";");
        return sb.toString();
    }
}
```
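The Status command reads its figures from ThreadPoolExecutor getters. The sketch below shows what those getters report. It runs a few no-op tasks on a two-thread pool, created the same way as in ConcurrentServer, and reads the statistics once the pool has terminated and its counters no longer change. The class name is hypothetical.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class PoolStatsDemo {

    // Runs 'tasks' no-op tasks on a fixed pool and returns the pool after termination.
    static ThreadPoolExecutor runTasks(int poolSize, int tasks) throws InterruptedException {
        ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(poolSize);
        for (int i = 0; i < tasks; i++) {
            executor.execute(() -> { /* no-op task */ });
        }
        executor.shutdown();
        executor.awaitTermination(1, TimeUnit.MINUTES);
        return executor;
    }

    public static void main(String[] args) throws InterruptedException {
        ThreadPoolExecutor executor = runTasks(2, 10);
        System.out.println("Core Pool Size: " + executor.getCorePoolSize());
        System.out.println("Maximum Pool Size: " + executor.getMaximumPoolSize());
        System.out.println("Largest Pool Size: " + executor.getLargestPoolSize());
        System.out.println("Completed Task Count: " + executor.getCompletedTaskCount());
        System.out.println("Task Count: " + executor.getTaskCount());
        System.out.println("Active Threads: " + executor.getActiveCount());
        System.out.println("Queue Size: " + executor.getQueue().size());
    }
}
```

For a pool created with newFixedThreadPool(n), the core and maximum pool sizes are both n. After termination, no tasks are active and the queue is empty.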
The following is the server code, together with the RequestTask class, which implements the Runnable interface:
```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class ConcurrentServer {

    private static ThreadPoolExecutor executor;
    private static ServerSocket serverSocket;
    private static volatile boolean stopped = false;

    public static void main(String[] args) throws InterruptedException, IOException {
        executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        System.out.println("Initialization completed.");
        serverSocket = new ServerSocket(Constants.CONCURRENT_PORT);
        // The main thread only accepts connections; each request runs as a task in the pool
        do {
            try {
                Socket clientSocket = serverSocket.accept();
                RequestTask task = new RequestTask(clientSocket);
                executor.execute(task);
            } catch (IOException e) {
                // shutdown() closes the socket, which makes accept() throw; report other errors only
                if (!stopped) {
                    e.printStackTrace();
                }
            }
        } while (!stopped);
        // Wait for the requests that are still being processed
        executor.awaitTermination(1, TimeUnit.DAYS);
        System.out.println("Shutting down cache");
        System.out.println("Cache ok");
        System.out.println("Main server thread ended");
    }

    // Called by ConcurrentStopCommand from a pool thread
    public static void shutdown() {
        stopped = true;
        System.out.println("Shutting down the server...");
        System.out.println("Shutting down executor");
        executor.shutdown(); // no new tasks; requests already running may finish
        System.out.println("Executor ok");
        System.out.println("Closing socket");
        try {
            serverSocket.close(); // unblocks accept() in the main thread
            System.out.println("Socket ok");
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("Shutting down logger");
        System.out.println("Logger ok");
    }

    public static ThreadPoolExecutor getExecutor() {
        return executor;
    }
}

// Processes one client request in a pool thread
public class RequestTask implements Runnable {

    private Socket clientSocket;

    public RequestTask(Socket clientSocket) {
        this.clientSocket = clientSocket;
    }

    @Override
    public void run() {
        try (PrintWriter out = new PrintWriter(clientSocket.getOutputStream(), true);
             BufferedReader in = new BufferedReader(new InputStreamReader(clientSocket.getInputStream()))) {
            String line = in.readLine();
            Command command;
            String[] commandData = line.split(";");
            System.out.println("Command: " + commandData[0]);
            switch (commandData[0]) {
            case "q":
                System.err.println("Query");
                command = new ConcurrentQueryCommand(commandData);
                break;
            case "r":
                System.err.println("Report");
                command = new ConcurrentReportCommand(commandData);
                break;
            case "s":
                System.err.println("Status");
                command = new ConcurrentStatusCommand(commandData);
                break;
            case "z":
                System.err.println("Stop");
                command = new ConcurrentStopCommand(commandData);
                break;
            default:
                System.err.println("Error");
                command = new ConcurrentErrorCommand(commandData);
                break;
            }
            String ret = command.execute();
            System.out.println(ret);
            out.println(ret); // send the response back to the client
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                clientSocket.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}
```
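ConcurrentServer waits up to one day in awaitTermination(). A common alternative, adapted from the usage example in the ExecutorService Javadoc, bounds the wait. If the pool has not terminated in time, it falls back to shutdownNow(), which interrupts running tasks and discards queued ones. The sketch below follows that pattern; the helper name and its timeout parameters are hypothetical. It only stops tasks that react to interruption, such as the sleeping task in main().

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class GracefulShutdown {

    // Returns true if the pool terminated, false if tasks were still running after both waits.
    static boolean shutdownAndAwait(ExecutorService pool, long timeout, TimeUnit unit) {
        pool.shutdown(); // stop accepting new tasks; already queued tasks still run
        try {
            if (!pool.awaitTermination(timeout, unit)) {
                pool.shutdownNow(); // interrupt running tasks and drop queued ones
                return pool.awaitTermination(timeout, unit);
            }
            return true;
        } catch (InterruptedException e) {
            pool.shutdownNow();
            Thread.currentThread().interrupt(); // preserve the interrupt status
            return false;
        }
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        // A task that would run for a minute but exits as soon as it is interrupted
        pool.execute(() -> {
            try {
                Thread.sleep(60_000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        System.out.println("Terminated: " + shutdownAndAwait(pool, 200, TimeUnit.MILLISECONDS));
    }
}
```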
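Besides newFixedThreadPool(), used in the examples above, the Executors class provides other methods for creating ThreadPoolExecutor objects. One example is newCachedThreadPool(), which creates new threads as needed and reuses idle ones.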
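The configuration each factory method produces can be inspected by casting the result to ThreadPoolExecutor, as the examples above already do for newFixedThreadPool(). The values described in the comments reflect the OpenJDK implementation. newSingleThreadExecutor() is left out because it returns a wrapper that cannot be cast this way. The class name is hypothetical.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class FactoryConfigDemo {

    // Summarizes the settings that distinguish one pool configuration from another
    static String describe(ThreadPoolExecutor executor) {
        return "core=" + executor.getCorePoolSize()
                + ", max=" + executor.getMaximumPoolSize()
                + ", keepAlive=" + executor.getKeepAliveTime(TimeUnit.SECONDS) + "s"
                + ", queue=" + executor.getQueue().getClass().getSimpleName();
    }

    public static void main(String[] args) {
        ThreadPoolExecutor fixed = (ThreadPoolExecutor) Executors.newFixedThreadPool(4);
        ThreadPoolExecutor cached = (ThreadPoolExecutor) Executors.newCachedThreadPool();
        // Fixed: core = max = 4; extra tasks wait in an unbounded queue
        System.out.println("fixed:  " + describe(fixed));
        // Cached: no core threads, effectively unbounded maximum, idle threads
        // are removed after 60 s, and tasks are handed directly to a thread
        System.out.println("cached: " + describe(cached));
        fixed.shutdown();
        cached.shutdown();
    }
}
```

A fixed pool suits CPU-bound work such as the k-NN classifier. A cached pool can grow without limit under load, which is the situation the Executor framework is meant to avoid for a server.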
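Java supports two kinds of concurrent data structures:

- Blocking data structures, whose methods can block the calling thread until the operation can be performed, for example when taking an element from an empty queue.
- Non-blocking data structures, whose methods never wait. If the operation cannot be performed, they return immediately, typically with a special value such as null or false, or by throwing an exception.

Some data structures implement both behaviors, while others implement only one. Usually, blocking data structures also provide methods with non-blocking behavior, but non-blocking data structures do not provide methods with blocking behavior.
For example, in the BlockingQueue interface, put() and take() are blocking operations: they wait until space or an element becomes available. offer() and poll() are non-blocking and return false or null immediately instead of waiting. add() and remove() are also non-blocking, but throw an exception instead.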
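Both behaviors can be seen on a single bounded queue. This sketch, with a hypothetical class name, uses a LinkedBlockingQueue with room for one element. It calls the non-blocking methods offer() and poll(), the timed offer(), and the blocking methods put() and take().

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

public class BlockingVsNonBlocking {

    public static void main(String[] args) throws InterruptedException {
        // A bounded queue with room for a single element
        BlockingQueue<String> queue = new LinkedBlockingQueue<>(1);

        // Non-blocking methods return a special value instead of waiting
        System.out.println(queue.poll());     // null: the queue is empty
        System.out.println(queue.offer("a")); // true: there was room
        System.out.println(queue.offer("b")); // false: the queue is full

        // Timed variant: waits at most 100 ms for space, then gives up
        System.out.println(queue.offer("b", 100, TimeUnit.MILLISECONDS)); // false

        // Blocking methods: put() waits for space, take() waits for an element
        Thread producer = new Thread(() -> {
            try {
                queue.put("c"); // blocks until take() below removes "a"
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        producer.start();
        System.out.println(queue.take()); // a
        System.out.println(queue.take()); // c, possibly after waiting for the producer
        producer.join();
    }
}
```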