?
?
贝叶斯模型
class="模型" name="code">package bayes; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; public class Model { public Set<String> categorySet = new HashSet<String>(); public Set<String> keyWordsSet = new HashSet<String>(); public Map<String, Long> probabilityMap = new HashMap<String, Long>(); }
?
贝叶斯主类
package bayes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Keyword-based naive Bayes text classifier.
 *
 * <p>Training counts, per category, how many documents contain each keyword
 * (substring match). Classification multiplies those counts for the keywords
 * found in a text and normalises the result over all categories. No smoothing
 * is applied: a keyword never seen with a category gives that category a
 * score of zero.
 *
 * <p>All counts share one map in {@link Model#probabilityMap}, keyed by
 * category name, keyword, or {@code category + "-" + keyword}; a keyword that
 * is spelled exactly like a category name therefore shares its counter.
 */
public class Bayes {

    /**
     * Computes the posterior probability of every category for a text.
     *
     * <p>For the matched keywords a1..ak and a category B the unnormalised
     * score is
     * <pre>
     *   c(B,a1) * c(B,a2) * ... * c(B,ak) / c(B)^(k-1)
     * </pre>
     * where {@code c(B)} is the number of training documents of B and
     * {@code c(B,ai)} the number of those documents containing keyword ai.
     * Scores are then divided by their sum over all categories.
     *
     * @param source text to classify
     * @param model  trained model
     * @return map from category to probability; a uniform distribution when
     *         every score is zero; {@code null} when an argument (or one of
     *         the model's sets) is {@code null} or no keyword occurs in
     *         {@code source}
     */
    public static Map<String, Double> getValue(String source, Model model) {
        if (model == null || model.keyWordsSet == null || model.categorySet == null
                || source == null) {
            return null;
        }

        Set<String> matchedKeywords = new HashSet<>();
        for (String keyword : model.keyWordsSet) {
            if (keyword != null && source.contains(keyword)) {
                matchedKeywords.add(keyword);
            }
        }
        if (matchedKeywords.isEmpty()) {
            return null;
        }

        Map<String, Double> scores = new HashMap<>();
        double totalScore = 0;
        for (String category : model.categorySet) {
            double score = categoryScore(model, category, matchedKeywords);
            totalScore += score;
            scores.put(category, score);
        }

        Map<String, Double> result = new HashMap<>();
        double uniform = 1.0 / model.categorySet.size();
        for (String category : model.categorySet) {
            // With no evidence for any category, fall back to a uniform guess.
            result.put(category, totalScore > 0 ? scores.get(category) / totalScore : uniform);
        }
        return result;
    }

    /** Unnormalised score of one category for the given matched keywords. */
    private static double categoryScore(Model model, String category, Set<String> matchedKeywords) {
        long categoryCount = getProbalityValue(model, category);
        if (categoryCount <= 0) {
            // A category without training documents would otherwise produce
            // 0.0 / 0.0 = NaN and turn the whole normalisation sum into NaN.
            return 0;
        }
        double numerator = 1;
        double denominator = 1;
        boolean first = true;
        for (String keyword : matchedKeywords) {
            if (!first) {
                // k keywords contribute c(B)^(k-1) to the denominator.
                denominator *= categoryCount;
            }
            numerator *= getProbalityValue(model, category + "-" + keyword);
            first = false;
        }
        return numerator / denominator;
    }

    /**
     * Looks up a counter in the model.
     *
     * @param model model to read from
     * @param key   category, keyword or {@code category + "-" + keyword}
     * @return the stored count, or {@code 0} when the model, its map, the key
     *         or the stored value is missing
     */
    public static long getProbalityValue(Model model, String key) {
        if (model == null || model.probabilityMap == null || key == null) {
            return 0;
        }
        Long count = model.probabilityMap.get(key);
        return count == null ? 0 : count;
    }

    /**
     * Builds a model from labelled documents.
     *
     * @param categorys category names; at least two are required
     * @param data      {@code data[i]} holds the training documents of
     *                  {@code categorys[i]}; must have the same length as
     *                  {@code categorys}
     * @param keyWords  keywords to count; at least two are required
     * @return the trained model, or {@code null} (after printing
     *         {@code "data error!"}) when the arguments violate the rules above
     */
    public static Model train(String[] categorys, String[][] data, String[] keyWords) {
        boolean valid = categorys != null && data != null
                && data.length == categorys.length && categorys.length > 1
                && keyWords != null && keyWords.length > 1;
        if (!valid) {
            System.out.println("data error!");
            return null;
        }

        Model model = new Model();
        model.categorySet.addAll(Arrays.asList(categorys));
        model.keyWordsSet.addAll(Arrays.asList(keyWords));
        for (int i = 0; i < categorys.length; i++) {
            calculateProbability(categorys[i], data[i], model);
        }
        return model;
    }

    /** Adds the counts contributed by all documents of one category. */
    private static void calculateProbability(String category, String[] categoryData, Model model) {
        if (categoryData == null) {
            return;
        }
        for (String source : categoryData) {
            if (source == null) {
                continue;
            }
            addCategoryKeywordCount(category, model);
            for (String keyword : model.keyWordsSet) {
                if (keyword != null && source.contains(keyword)) {
                    addCategoryKeywordCount(keyword, model);
                    addCategoryKeywordCount(category + "-" + keyword, model);
                }
            }
        }
    }

    /** Increments the counter stored under {@code key}, starting at 1. */
    private static void addCategoryKeywordCount(String key, Model model) {
        Long count = model.probabilityMap.get(key);
        model.probabilityMap.put(key, count == null ? 1L : count + 1);
    }

    /**
     * Writes a model as UTF-8 text: one comma-terminated line of categories,
     * one of keywords, then one {@code key:count} line per counter.
     *
     * <p>Failures are reported on standard output and not rethrown. Category
     * or keyword names containing a comma or line break cannot be read back.
     *
     * @param fileName target file; created or truncated
     * @param model    model to write
     */
    public static void saveModel(String fileName, Model model) {
        try (BufferedWriter writer =
                Files.newBufferedWriter(Paths.get(fileName), StandardCharsets.UTF_8)) {
            writeCommaTerminated(writer, model.categorySet);
            writeCommaTerminated(writer, model.keyWordsSet);
            for (Map.Entry<String, Long> entry : model.probabilityMap.entrySet()) {
                writer.write(entry.getKey());
                writer.write(":");
                writer.write(String.valueOf(entry.getValue()));
                writer.write("\n");
            }
            writer.write("\n");
        } catch (IOException | RuntimeException e) {
            System.out.println("save Model error: " + e);
        }
    }

    /** Writes every item followed by a comma, then a line break. */
    private static void writeCommaTerminated(BufferedWriter writer, Set<String> items)
            throws IOException {
        for (String item : items) {
            writer.write(item);
            writer.write(",");
        }
        writer.write("\n");
    }

    /**
     * Reads a model written by {@link #saveModel(String, Model)}.
     *
     * <p>Malformed counter lines are reported on standard output and skipped.
     * An I/O failure is reported the same way; whatever was read up to that
     * point is kept.
     *
     * @param fileName file to read (UTF-8)
     * @return the loaded model, never {@code null}; possibly empty or partial
     */
    public static Model loadModel(String fileName) {
        Model model = new Model();
        try (BufferedReader reader =
                Files.newBufferedReader(Paths.get(fileName), StandardCharsets.UTF_8)) {
            model.categorySet.addAll(getStringSet(reader.readLine(), ","));
            model.keyWordsSet.addAll(getStringSet(reader.readLine(), ","));

            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                // Split at the LAST colon so that keys containing ':' survive.
                int separator = line.lastIndexOf(':');
                if (separator <= 0 || separator == line.length() - 1) {
                    System.out.println("Error model line:" + line);
                    continue;
                }
                String key = line.substring(0, separator);
                try {
                    // The count is the part AFTER the separator, not the key.
                    Long count = Long.valueOf(line.substring(separator + 1).trim());
                    model.probabilityMap.put(key, count);
                } catch (NumberFormatException e) {
                    System.out.println("Error model line:" + line);
                }
            }
        } catch (IOException | RuntimeException e) {
            System.out.println("load model error: " + e);
        }
        return model;
    }

    /**
     * Splits a string and collects the trimmed, non-empty parts.
     *
     * @param sourceStr text to split; may be {@code null}
     * @param splitor   separator, interpreted as a regular expression by
     *                  {@link String#split(String)}; may be {@code null}
     * @return the distinct parts; empty when either argument is {@code null}
     */
    public static Set<String> getStringSet(String sourceStr, String splitor) {
        Set<String> parts = new HashSet<>();
        if (sourceStr == null || splitor == null) {
            return parts;
        }
        for (String part : sourceStr.split(splitor)) {
            String trimmed = part.trim();
            // An empty token would later act as a keyword matching every text.
            if (!trimmed.isEmpty()) {
                parts.add(trimmed);
            }
        }
        return parts;
    }
}
?