朴素贝叶斯算法java实现(多项式模型)
网上有很多对朴素贝叶斯算法的说明的文章,在对算法实现前,参考了一下几篇文章:
其中“带你搞懂朴素贝叶斯算法”在我看来比较容易理解,上面两篇比较详细,更深入。
算法java实现
第一步对训练集进行预处理,分词并计算词频,得到存储训练集的特征集合
/** * 所有训练集分词特征集合 * 第一个String代表分类标签,也就是存储该类别训练集的文件名 * 第二个String代表某条训练集的路径,这里存储的是该条语料的绝对路径 * Map<String, Integer>存储的是该条训练集的特征词和词频 * */ private static Map<String, Map<String, Map<String, Integer>>> allTrainFileSegsMap = new HashMap<String, Map<String, Map<String, Integer>>>(); /** * 放大因子 * 在计算中,因各个词的先验概率都比较小,我们乘以固定的值放大,便于计算 */ private static BigDecimal zoomFactor = new BigDecimal(10); /** * 对传入的训练集进行分词,获取训练集分词后的词和词频集合 * @param trainFilePath 训练集路径 */ public static void getFeatureClassForTrainText(String trainFilePath){ //通过将训练集路径字符串转变成抽象路径,创建一个File对象 File trainFileDirs = new File(trainFilePath); //获取该路径下的所有分类路径 File[] trainFileDirList = trainFileDirs.listFiles(); if (trainFileDirList == null){ System.out.println("训练数据集不存在"); } for (File trainFileDir : trainFileDirList){ //读取该分类下的所有训练文件 List<String> fileList = null; try { fileList = FileOptionUtil.readDirs(trainFileDir.getAbsolutePath()); if (fileList.size() != 0){ //遍历训练集目录数据,进行分词和类别标签处理 for(String filePath : fileList){ System.out.println("开始对此训练集进行分词处理:" + filePath); //分词处理,获取每条训练集文本的词和词频 //若知道文件编码的话,不要用下述的判断编码格式了,效率太低 // Map<String, Integer> contentSegs = IKWordSegmentation.segString(FileOptionUtil.readFile(filePath, FileOptionUtil.getCodeString(filePath))); Map<String, Integer> contentSegs = IKWordSegmentation.segString(FileOptionUtil.readFile(filePath, "gbk")); if (allTrainFileSegsMap.containsKey(trainFileDir.getName())){ Map<String, Map<String, Integer>> allSegsMap = allTrainFileSegsMap.get(trainFileDir.getName()); allSegsMap.put(filePath, contentSegs); allTrainFileSegsMap.put(trainFileDir.getName(), allSegsMap); } else { Map<String, Map<String, Integer>> allSegsMap = new HashMap<String, Map<String, Integer>>(); allSegsMap.put(filePath, contentSegs); allTrainFileSegsMap.put(trainFileDir.getName(), allSegsMap); } } } else { System.out.println("该分类下没有待训练语料"); } } catch (IOException e) { e.printStackTrace(); } } }
第二步计算类别的先验概率
/** * 计算类别C的先验概率 * 先验概率P(c)= 类c下单词总数/整个训练样本的单词总数 * @param category * @return 类C的先验概率 */ public static BigDecimal prioriProbability(String category){ BigDecimal categoryWordsCount = new BigDecimal(categoryWordCount(category)); BigDecimal allTrainFileWordCount = new BigDecimal(getAllTrainCategoryWordsCount()); return categoryWordsCount.divide(allTrainFileWordCount, 10, BigDecimal.ROUND_CEILING); }
第三步计算特征词的条件概率
/** * 多项式朴素贝叶斯类条件概率 * 类条件概率P(IK|c)=(类c下单词IK在各个文档中出现过的次数之和+1)/(类c下单词总数+|V|) * V是训练样本的单词表(即抽取单词,单词出现多次,只算一个), * |V|则表示训练样本包含多少种单词。 P(IK|c)可以看作是单词tk在证明d属于类c上提供了多大的证据, * 而P(c)则可以认为是类别c在整体上占多大比例(有多大可能性) * @param category * @param word * @return */ public static BigDecimal categoryConditionalProbability(String category, String word){ BigDecimal wordCount = new BigDecimal(wordInCategoryCount(word, category) + 1); BigDecimal categoryTrainFileWordCount = new BigDecimal(categoryWordCount(category) + getAllTrainCategoryWordCount()); return wordCount.divide(categoryTrainFileWordCount, 10, BigDecimal.ROUND_CEILING); }
第四步计算给定文本的分类结果
/** * 多项式朴素贝叶斯分类结果 * P(C_i|w_1,w_2...w_n) = P(w_1,w_2...w_n|C_i) * P(C_i) / P(w_1,w_2...w_n) * = P(w_1|C_i) * P(w_2|C_i)...P(w_n|C_i) * P(C_i) / (P(w_1) * P(w_2) ...P(w_n)) * @param words * @return */ public static Map<String, BigDecimal> classifyResult(Set<String> words){ Map<String, BigDecimal> resultMap = new HashMap<String, BigDecimal>(); //获取训练语料集所有的分类集合 Set<String> categorySet = allTrainFileSegsMap.keySet(); //循环计算每个类别的概率 for (String categorySetLabel : categorySet){ BigDecimal probability = new BigDecimal(1.0); for (String word : words){ probability = probability.multiply(categoryConditionalProbability(categorySetLabel, word)).multiply(zoomFactor); } resultMap.put(categorySetLabel, probability.multiply(prioriProbability(categorySetLabel))); } return resultMap; }
辅助函数
/** * 对分类结果进行比较,得出概率最大的类 * @param classifyResult * @return */ public static String getClassifyResultName(Map<String, BigDecimal> classifyResult){ String classifyName = ""; if (classifyResult.isEmpty()){ return classifyName; } BigDecimal result = new BigDecimal(0); Set<String> classifyResultSet = classifyResult.keySet(); for (String classifyResultSetString : classifyResultSet){ if (classifyResult.get(classifyResultSetString).compareTo(result) >= 1){ result = classifyResult.get(classifyResultSetString); classifyName = classifyResultSetString; } } return classifyName; } /** * 统计给定类别下的单词总数(带词频计算) * @param categoryLabel 指定类别参数 * @return */ public static Long categoryWordCount(String categoryLabel){ Long sum = 0L; Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(categoryLabel); if (categoryWordMap == null){ return sum; } Set<String> categoryWordMapKeySet = categoryWordMap.keySet(); for (String categoryLabelString : categoryWordMapKeySet){ Map<String, Integer> categoryWordMapDataMap = categoryWordMap.get(categoryLabelString); List<Map.Entry<String, Integer>> dataWordMapList = new ArrayList<Map.Entry<String, Integer>>(categoryWordMapDataMap.entrySet()); for (int i=0; i<dataWordMapList.size(); i++){ sum += dataWordMapList.get(i).getValue(); } } return sum; } /** * 获取训练样本所有词的总数(词总数计算是带上词频的,也就是可以重复算数) * @return */ public static Long getAllTrainCategoryWordsCount(){ Long sum = 0L; //获取所有分类 Set<String> categoryLabels = allTrainFileSegsMap.keySet(); //循环相加每个类下的词总数 for (String categoryLabel : categoryLabels){ sum += categoryWordCount(categoryLabel); } return sum; } /** * 获取训练样本下各个类别不重复词的总词数,区别于getAllTrainCategoryWordsCount()方法,此处计算不计算词频 * 备注:此处并不是严格意义上的进行全量词表生成后的计算,也就是加入类别1有"中国=6"、类别2有"中国=2",总词数算中国两次, * 也就是说,我们在计算的时候并没有生成全局词表(将所有词都作为出现一次) * @return */ public static Long getAllTrainCategoryWordCount(){ Long sum = 0L; //获取所有分类 Set<String> categoryLabels = allTrainFileSegsMap.keySet(); for (String cateGoryLabelsLabel : categoryLabels){ Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(cateGoryLabelsLabel); List<Map.Entry<String, Map<String, Integer>>> categoryWordMapList = new ArrayList<Map.Entry<String, Map<String, Integer>>>(categoryWordMap.entrySet()); for (int i=0; i<categoryWordMapList.size(); i++){ sum += categoryWordMapList.get(i).getValue().size(); } } return sum; } /** * 计算测试数据的每个单词在每个类下出现的总数 * @param word * @param categoryLabel * @return */ public static Long wordInCategoryCount(String word, String categoryLabel){ Long sum = 0L; Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(categoryLabel); Set<String> categoryWordMapKeySet = categoryWordMap.keySet(); for (String categoryWordMapKeySetFile : categoryWordMapKeySet){ Map<String, Integer> categoryWordMapDataMap = categoryWordMap.get(categoryWordMapKeySetFile); Integer value = categoryWordMapDataMap.get(word); if (value!=null && value>0){ sum += value; } } return sum; } /** * 获取所有分类类别 * @return */ public Set<String> getAllCategory(){ return allTrainFileSegsMap.keySet(); }
main函数测试
//main方法 public static void main(String[] args){ BayesNB.getFeatureClassForTrainText("/Users/zhouyh/work/yanfa/xunlianji/train/"); String s = "全国假日旅游部际协调会议的各成员单位和中央各有关部门围绕一个目标,积极配合,主动工作,抓得深入,抓得扎实。主要有以下几个特点:一是安全工作有部署有检查有跟踪。国务院安委会办公室节前深入部署全面检查,节中及时总结,下发关于黄金周后期安全工作的紧急通知;铁路、民航、交通等部门针对黄金周前后期旅客集中返程交通压力较大情况,及时调遣应急运力;质检总局进一步强化节日期间质量安全监管工作;旅游部门每日及时发布旅游信息通报,有效引导游客。二是各方面主动协调密切配合。各省区市加强了在安全事故问题上的协调与沟通,化解了一些跨省区矛盾和问题;铁道、民航部门准时准确报送信息;中宣部和中央文明办以黄金周旅游为载体," + "部署精神文明建设和践行社会主义荣辱观的宣传活动;中国气象局及时将黄金周每日气象分析送交各有关部门;公安部专门部署警力,为协调游客流动大的城市及景区做了大量工作;旅游部门密切配合有关部门做好各类事故处理和投诉调解工作。三是政府各部门的社会服务意识大为增强。外交部及其驻外领事馆及时提供境外安全信息为旅游者服务;中央电视台、地方电视台和各大媒体及各地方媒体提供的旅游信息十分丰富;气象信息服务充分具体;中消协提出多项旅游警示。各部门的密切配合和主动服务配合,确保了本次黄金周的顺利平稳运行。"; Set<String> words = IKWordSegmentation.segString(s).keySet(); Map<String, BigDecimal> resultMap = BayesNB.classifyResult(words); String category = BayesNB.getClassifyResultName(resultMap); System.out.println(category); }
经过上述步骤即可实现简单的多项式模型算法,有部分代码参考了网上的算法代码。