java实现朴素贝叶斯算法_JAVA_编程开发_程序员俱乐部

中国优秀的程序员网站程序员频道CXYCLUB技术地图
热搜:
更多>>
 
您所在的位置: 程序员俱乐部 > 编程开发 > JAVA > java实现朴素贝叶斯算法

java实现朴素贝叶斯算法

 2019/2/3 18:32:47  zhoupinheng  程序员俱乐部  我要评论(0)
  • 摘要:贝叶斯模型packagebayes;importjava.util.HashMap;importjava.util.HashSet;importjava.util.Map;importjava.util.Set;publicclassModel{publicSet<String>categorySet=newHashSet<String>();publicSet<String>keyWordsSet=newHashSet<String>()
  • 标签:实现 Java 算法

?

?

贝叶斯模型

package bayes;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * State of a trained naive Bayes classifier, as filled in by {@code Bayes.train} and
 * {@code Bayes.loadModel}.
 *
 * <p>Despite its name, {@link #probabilityMap} stores raw occurrence counts rather than
 * probabilities. Probabilities are derived from these counts at classification time.
 */
public class Model {
	/** Every category label known to the model. */
	public Set<String> categorySet = new HashSet<>();

	/** Vocabulary of keywords looked for (by substring match) in each text. */
	public Set<String> keyWordsSet = new HashSet<>();

	/**
	 * Occurrence counts. Keys are a category name (number of training texts in it), a keyword,
	 * or {@code "category-keyword"} (number of that category's texts containing the keyword).
	 */
	public Map<String, Long> probabilityMap = new HashMap<>();
}

?

贝叶斯主类

package bayes;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Naive Bayes text classifier based on substring keyword matching.
 *
 * <p>Training counts, per category, how many training texts it has and how many of those texts
 * contain each keyword. Classification turns those counts into a posterior probability per
 * category. Models can be persisted to, and restored from, a simple line-based text file.
 */
public class Bayes {

	/**
	 * Classifies {@code source} against a trained model.
	 *
	 * <p>Let {@code K} be the model keywords that occur in {@code source} (substring match) and
	 * {@code n = |K|}. Let {@code c(C)} be the number of training texts of category {@code C}.
	 * Let {@code c(C,k)} be the number of those texts containing keyword {@code k}. The
	 * unnormalised score of a category is
	 *
	 * <pre>
	 * score(C) = c(C,k1) * c(C,k2) * ... * c(C,kn) / c(C)^(n-1)
	 * </pre>
	 *
	 * <p>This is proportional to {@code P(C) * P(k1|C) * ... * P(kn|C)}; the shared factor
	 * {@code 1/N} cancels on normalisation. Scores are divided by their sum. If every score is
	 * zero, each category receives {@code 1 / categoryCount}. No smoothing is applied, so a
	 * keyword never seen with a category zeroes that category's score.
	 *
	 * @param source text to classify
	 * @param model trained model
	 * @return category to posterior probability, or {@code null} if an argument (or a model set)
	 *     is {@code null} or no model keyword occurs in {@code source}
	 */
	public static Map<String, Double> getValue(String source, Model model) {
		if (model == null || model.keyWordsSet == null || model.categorySet == null || source == null) {
			return null;
		}

		Set<String> keyWordSet = new HashSet<String>();
		for (String key : model.keyWordsSet) {
			if (source.contains(key)) {
				keyWordSet.add(key);
			}
		}
		if (keyWordSet.isEmpty()) {
			return null;
		}

		Map<String, Double> scoreMap = new HashMap<String, Double>();
		double sumScore = 0;
		for (String category : model.categorySet) {
			double numerator = 1;
			for (String keyword : keyWordSet) {
				numerator = numerator * getProbalityValue(model, category + "-" + keyword);
			}

			// c(C)^(n-1): one factor of the category count per matched keyword beyond the first.
			long categoryCount = getProbalityValue(model, category);
			double denominator = 1;
			for (int i = 1; i < keyWordSet.size(); i++) {
				denominator = denominator * categoryCount;
			}

			// A category without training texts would yield 0/0 = NaN and poison the normalising
			// sum for every category; it carries no evidence, so it scores 0 instead.
			double score = denominator > 0 ? numerator / denominator : 0;
			sumScore = sumScore + score;
			scoreMap.put(category, score);
		}

		Map<String, Double> rtnMap = new HashMap<String, Double>();
		for (String category : model.categorySet) {
			double value = sumScore > 0
					? scoreMap.get(category) / sumScore
					: 1.0 / model.categorySet.size();
			rtnMap.put(category, value);
		}
		return rtnMap;
	}

	/**
	 * Returns the stored count for {@code key}, or 0 if the model has none.
	 *
	 * @param model model to read from
	 * @param key a category name, a keyword, or a {@code "category-keyword"} key
	 * @return the count, or 0 when absent
	 */
	public static long getProbalityValue(Model model, String key) {
		Long count = model.probabilityMap.get(key);
		return count == null ? 0L : count.longValue();
	}

	/**
	 * Trains a model from labelled texts.
	 *
	 * @param categorys category label of each group; at least two entries
	 * @param data {@code data[i]} holds the training texts for {@code categorys[i]}; must have the
	 *     same length as {@code categorys}. {@code null} groups and texts are ignored.
	 * @param keyWords keywords to look for; at least one entry. {@code null} entries are ignored.
	 * @return the trained model, or {@code null} (after printing {@code "data error!"}) if the
	 *     arguments are invalid
	 */
	public static Model train(String[] categorys, String[][] data, String[] keyWords) {
		if (categorys == null || data == null || data.length != categorys.length || categorys.length < 2
				|| keyWords == null || keyWords.length < 1) {
			System.out.println("data error!");
			return null;
		}

		Model model = new Model();
		model.categorySet.addAll(Arrays.asList(categorys));
		for (String keyWord : keyWords) {
			// A null keyword would make String.contains throw during training.
			if (keyWord != null) {
				model.keyWordsSet.add(keyWord);
			}
		}

		for (int i = 0; i < categorys.length; i++) {
			calculateProbability(categorys[i], data[i], model);
		}
		return model;
	}

	/** Adds the counts contributed by one category's training texts to {@code model}. */
	private static void calculateProbability(String category, String[] categoryData, Model model) {
		if (categoryData == null) {
			return;
		}
		for (String source : categoryData) {
			if (source == null) {
				continue;
			}
			addCategoryKeywordCount(category, model);
			for (String keyword : model.keyWordsSet) {
				if (source.contains(keyword)) {
					// NOTE(review): the bare keyword total shares its key space with category
					// names, so a keyword equal to a category name inflates that category's
					// count. getValue never reads this total; kept for model-file compatibility.
					addCategoryKeywordCount(keyword, model);
					addCategoryKeywordCount(category + "-" + keyword, model);
				}
			}
		}
	}

	/** Increments the count stored under {@code key}, starting from 1 when absent. */
	private static void addCategoryKeywordCount(String key, Model model) {
		Long count = model.probabilityMap.get(key);
		model.probabilityMap.put(key, count == null ? 1L : count + 1);
	}

	/**
	 * Writes {@code model} to {@code fileName} using the platform default charset.
	 *
	 * <p>Format:
	 * <ul>
	 * <li>line 1: categories, each followed by {@code ','};</li>
	 * <li>line 2: keywords, each followed by {@code ','};</li>
	 * <li>then one {@code key:count} line per count, and a final blank line.</li>
	 * </ul>
	 * Categories and keywords must not contain {@code ','}, or they will not load back intact.
	 * Errors are reported on standard output rather than thrown.
	 *
	 * @param fileName destination file
	 * @param model model to save
	 */
	public static void saveModel(String fileName, Model model) {
		if (fileName == null || model == null) {
			System.out.println("save Model error");
			return;
		}
		try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
			for (String category : model.categorySet) {
				writer.write(category);
				writer.write(",");
			}
			writer.write("\n");

			for (String keyword : model.keyWordsSet) {
				writer.write(keyword);
				writer.write(",");
			}
			writer.write("\n");

			for (Map.Entry<String, Long> entry : model.probabilityMap.entrySet()) {
				writer.write(entry.getKey());
				writer.write(":");
				writer.write(String.valueOf(entry.getValue()));
				writer.write("\n");
			}
			writer.write("\n");
		} catch (IOException e) {
			System.out.println("save Model error: " + e);
		}
	}

	/**
	 * Reads a model written by {@link #saveModel(String, Model)}.
	 *
	 * <p>Malformed count lines are reported and skipped; the remaining lines are still loaded.
	 * I/O errors are reported on standard output.
	 *
	 * @param fileName file to read
	 * @return the loaded model; never {@code null}, but possibly partially filled on error
	 */
	public static Model loadModel(String fileName) {
		Model model = new Model();
		if (fileName == null) {
			System.out.println("load model error");
			return model;
		}

		try (BufferedReader reader = new BufferedReader(new FileReader(fileName))) {
			model.categorySet.addAll(getStringSet(reader.readLine(), ","));
			model.keyWordsSet.addAll(getStringSet(reader.readLine(), ","));

			String line;
			while ((line = reader.readLine()) != null) {
				if (line.trim().isEmpty()) {
					continue;
				}
				// Split at the last ':' so a key that contains ':' still loads; counts never do.
				int separator = line.lastIndexOf(':');
				if (separator < 0) {
					System.out.println("Error model line:" + line);
					continue;
				}
				String key = line.substring(0, separator);
				try {
					model.probabilityMap.put(key, Long.parseLong(line.substring(separator + 1).trim()));
				} catch (NumberFormatException e) {
					System.out.println("Error model line:" + line);
				}
			}
		} catch (IOException e) {
			System.out.println("load model error: " + e);
		}

		return model;
	}

	/**
	 * Splits {@code sourceStr} on {@code splitor}, trimming each token and dropping empty ones.
	 *
	 * <p>Empty tokens are dropped so that an empty saved set (an empty line) loads back as an
	 * empty set rather than as a set containing {@code ""}.
	 *
	 * @param sourceStr text to split; {@code null} yields an empty set
	 * @param splitor separator, interpreted as a regular expression (see {@link String#split})
	 * @return the distinct non-empty trimmed tokens
	 */
	public static Set<String> getStringSet(String sourceStr, String splitor) {
		Set<String> rtn = new HashSet<String>();
		if (sourceStr == null || splitor == null) {
			return rtn;
		}
		for (String str : sourceStr.split(splitor)) {
			String item = str.trim();
			if (!item.isEmpty()) {
				rtn.add(item);
			}
		}
		return rtn;
	}

}

?

上一篇: 使用简单的Java代码在SAP C4C里创建销售订单 下一篇: 没有下一篇了!
发表评论
用户名: 匿名