一. 概率论基础
1. 条件概率公式:$P(A\mid B)=\dfrac{P(AB)}{P(B)}$
2. 全概率公式:$P(A)=\sum_{i} P(A\mid B_i)\,P(B_i)$
3. 由条件概率公式和全概率公式可以导出贝叶斯公式:$P(B_i\mid A)=\dfrac{P(A\mid B_i)\,P(B_i)}{\sum_{j} P(A\mid B_j)\,P(B_j)}$
二. 文本分类
要计算一篇文章D所属的类别c(D),相当于计算生成D的可能性最大的类别,即:$c(D)=\arg\max_{c} P(c\mid D)=\arg\max_{c}\dfrac{P(D\mid c)\,P(c)}{P(D)}$
其中P(D)与C无关,故 $c(D)=\arg\max_{c} P(D\mid c)\,P(c)$
三. 朴素贝叶斯分类模型
朴素贝叶斯假设:在给定类别C的条件下,所有属性Di相互独立,即,$P(D\mid c)=\prod_{i} P(d_i\mid c)$
根据朴素贝叶斯假设,可得 $c(D)=\arg\max_{c} P(c)\prod_{i} P(d_i\mid c)$
其中,$P(c)=\dfrac{N(c)}{N}$,$P(d_i\mid c)=\dfrac{N(d_i,c)}{N(c)}$
$N(c)$:类别c中的训练文本数
$N$:总训练文本数
$N(d_i,c)$:单词$d_i$在类别c中出现的次数
综上可得,$c(D)=\arg\max_{c}\dfrac{N(c)}{N}\prod_{i}\dfrac{N(d_i,c)}{N(c)}$
四. 具体代码( 源代码)
程序采用java语言进行编写,运用搜狗语料库进行训练。具体程序代码如下:
Main.java——主程序,负责读取待分类文章以及调用分类器
package classifierDemo;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
public class Main {
    /**
     * Entry point: reads the article to classify from article.txt (GBK-encoded)
     * and hands it to the Naive Bayes classifier.
     *
     * @param args unused command-line arguments
     * @throws IOException if the article file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String path = "article.txt";
        // StringBuilder avoids the O(n^2) cost of += on String inside the loop.
        StringBuilder article = new StringBuilder();
        BufferedReader br = new BufferedReader(new InputStreamReader(
                new FileInputStream(path), "gbk"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                article.append(line);
            }
        } finally {
            br.close(); // release the file handle even if readLine throws
        }
        System.out.println(article + "\n");
        // Classify the article.
        TrainDataManager train = new TrainDataManager();
        train.execute(article.toString());
    }
}
TrainDataManager.java——分类器,对输入文本进行分类
package classifierDemo;
import java.io.File;
import java.io.IOException;
import java.util.Vector;
public class TrainDataManager {
    private static final String dirName = "trainingData/Sample"; // training-set root directory
    CountDirectory countDir = new CountDirectory();

    /**
     * Prior probability p(c) = (documents in class c) / (total training documents).
     *
     * @param className class name (a sub-directory of the training-set root)
     * @return the prior probability of the class
     */
    public double priorProbability(String className) {
        return (double) countDir.countClass(className) / countDir.countSum();
    }

    /**
     * Classifies an article with the Naive Bayes rule and prints the winning class.
     *
     * Scores are accumulated in log space: the original product of per-word
     * probabilities underflows to 0 for long articles, and multiplying every
     * term by a fixed "zoom factor" scales all classes identically so it cannot
     * restore discrimination. Summing logs is numerically stable and preserves
     * the argmax, so the printed result is the same whenever the product did
     * not underflow.
     *
     * @param article raw article text to classify
     * @throws IOException if the training files cannot be read
     */
    public void execute(String article) throws IOException {
        // Segment the article into words (stop words already removed).
        Vector<String> words = ChineseSpliter.splitWords(article);
        File[] classDirs = new File(dirName).listFiles();
        if (classDirs == null || classDirs.length == 0) {
            // listFiles() returns null for a missing/non-directory path.
            throw new IOException("training directory not found or empty: " + dirName);
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestId = 0;
        for (int i = 0; i < classDirs.length; i++) {
            String className = classDirs[i].getName();
            double countc = countDir.countClass(className);
            // log p(c) + sum over words of log p(word|c), with add-one smoothing.
            double score = Math.log(priorProbability(className));
            for (String word : words) {
                score += Math.log((countDir.countWordInClass(word, className) + 1) / countc);
            }
            if (score > bestScore) {
                bestScore = score;
                bestId = i;
            }
        }
        System.out.println("文章所属分类为:" + classDirs[bestId].getName());
    }
}
CountDirectory.java——用于计算训练集中的各种频次
package classifierDemo;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
/**
 * Frequency counters over the training-set directory tree: documents per
 * class, total documents, and documents of a class containing a given word.
 */
public class CountDirectory {
    private static final String dirName = "trainingData/Sample"; // training-set root directory

    /**
     * Counts all training documents across every class directory.
     *
     * @return total number of training files under the training-set root,
     *         0 if the root directory is missing
     */
    public int countSum() {
        File[] classDirs = new File(dirName).listFiles();
        if (classDirs == null) {
            return 0; // root missing or not a directory: no training data
        }
        int sum = 0;
        for (File classDir : classDirs) {
            sum += countClass(classDir.getName());
        }
        return sum;
    }

    /**
     * Counts the documents belonging to one class.
     *
     * @param className class name (a sub-directory of the training-set root)
     * @return number of files in the class directory, 0 if it does not exist
     */
    public int countClass(String className) {
        // listFiles() returns null for a missing/non-directory path;
        // the original code NPE'd here on a bad class name.
        File[] subFiles = new File(dirName + "/" + className).listFiles();
        return subFiles == null ? 0 : subFiles.length;
    }

    /**
     * Counts how many documents of a class contain the given word.
     *
     * @param word      word to look for
     * @param className class name
     * @return number of documents in className whose text contains word
     * @throws IOException if a training file cannot be read
     */
    public int countWordInClass(String word, String className)
            throws IOException {
        File[] subFiles = new File(dirName + "/" + className).listFiles();
        if (subFiles == null) {
            return 0; // unknown class: no documents to match
        }
        int count = 0;
        for (File file : subFiles) {
            // Read the whole document; StringBuilder avoids quadratic
            // concatenation, and finally closes the reader even on error.
            StringBuilder content = new StringBuilder();
            BufferedReader br = new BufferedReader(new InputStreamReader(
                    new FileInputStream(file.getAbsolutePath()), "gbk"));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    content.append(line);
                }
            } finally {
                br.close();
            }
            if (content.indexOf(word) >= 0) {
                count++;
            }
        }
        return count;
    }
}
ChineseSpliter.java——中文分词器,对输入字串进行中文分词
package classifierDemo;
import java.io.IOException;
import java.util.Vector;
import jeasy.analysis.MMAnalyzer;
/**
 * Chinese word segmenter: splits input text into words and removes stop words.
 */
public class ChineseSpliter {
    private static String splitToken = "|"; // separator the segmenter puts between words

    /**
     * Segments an article into Chinese words and removes stop words.
     *
     * @param article the article text to segment
     * @return vector of words with stop words removed; empty if segmentation failed
     */
    public static Vector<String> splitWords(String article) {
        String result = null;
        MMAnalyzer analyzer = new MMAnalyzer();
        try {
            result = analyzer.segment(article, splitToken);
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (result == null) {
            // Segmentation failed: the original fell through and NPE'd in
            // stringToVector; return an empty vector instead.
            return new Vector<String>();
        }
        Vector<String> vector = stringToVector(result);
        // Remove stop words.
        StopWordsHandler stopWords = new StopWordsHandler();
        return stopWords.DropStopWords(vector);
    }

    /**
     * Splits a separator-delimited string into a vector of tokens.
     *
     * Bug fix: the original loop only collected text appearing BEFORE a
     * separator, so a final token not followed by the separator was silently
     * dropped; the trailing remainder is now kept.
     *
     * @param str separator-delimited segmentation result
     * @return the tokens in order of appearance
     */
    public static Vector<String> stringToVector(String str) {
        Vector<String> vector = new Vector<String>();
        int index;
        while ((index = str.indexOf(splitToken)) != -1) {
            vector.add(str.substring(0, index));
            str = str.substring(index + 1);
        }
        if (str.length() > 0) {
            vector.add(str); // keep the final token (was lost before)
        }
        return vector;
    }
}
StopWordsHandler.java——停用词处理
package classifierDemo;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;
/**
 * Stop-word handler: loads a stop-word list from file and filters
 * segmented text against it (case-insensitively).
 */
public class StopWordsHandler {
    private List<String> stopWordsList; // common stop words
    private String path = "chineseStopWords.txt";

    /**
     * Loads the default stop-word list. On failure the handler degrades to an
     * empty list: the original left the field null, so every later call to
     * IsStopWord threw a NullPointerException.
     */
    public StopWordsHandler() {
        try {
            stopWordsList = readStopWords();
        } catch (Exception e) {
            System.out.println(e);
            stopWordsList = new ArrayList<String>(); // degrade gracefully
        }
    }

    /**
     * Loads a stop-word list from the given file.
     *
     * @param path path of the stop-word file (GBK-encoded, one word per line)
     * @throws FileNotFoundException if the file does not exist
     * @throws IOException if the file cannot be read
     */
    public StopWordsHandler(String path) throws FileNotFoundException,
            IOException {
        this.path = path;
        stopWordsList = readStopWords();
    }

    /**
     * Reads the stop-word file, one word per line.
     *
     * @return all stop words in file order
     * @throws FileNotFoundException if the file does not exist
     * @throws IOException if the file cannot be read
     */
    public List<String> readStopWords() throws FileNotFoundException,
            IOException {
        List<String> words = new ArrayList<String>();
        BufferedReader br = new BufferedReader(new InputStreamReader(
                new FileInputStream(path), "gbk"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                words.add(line);
            }
        } finally {
            br.close(); // release the file handle even if readLine throws
        }
        return words;
    }

    /**
     * Tests whether a word is a stop word (case-insensitive).
     *
     * @param word the word to test
     * @return true if the word appears in the stop-word list
     */
    public boolean IsStopWord(String word) {
        for (String stopWord : stopWordsList) {
            if (word.equalsIgnoreCase(stopWord)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Removes all stop words from a word vector.
     *
     * @param oldWords segmented words
     * @return a new vector containing only the non-stop words, in order
     */
    public Vector<String> DropStopWords(Vector<String> oldWords) {
        Vector<String> kept = new Vector<String>();
        for (String word : oldWords) {
            if (!IsStopWord(word)) {
                kept.add(word);
            }
        }
        return kept;
    }
}
五. 参考文献: