贝叶斯文本分类 java实现
昨天实现了一个基于贝叶斯定理的的文本分类,贝叶斯定理假设特征属性(在文本中就是词汇)对待分类项的影响都是独立的,道理比较简单,在中文分类系统中,分类的准确性与分词系统的好坏有很大的关系,这段代码也是试验不同分词系统才顺手写的一个。
试验数据用的sogou实验室的文本分类样本,一共分为9个类别,每个类别文件夹下大约有2000篇文章。由于文本数据量确实较大,所以得想办法让每次训练的结果都能保存起来,以便于下次直接使用,我这里使用序列化的方式保存在硬盘。
训练代码如下:
1 /** 2 * 训练器 3 * 4 * <a href="http://my.oschina.net/arthor" target="_blank" rel="nofollow">@author</a> duyf 5 * 6 */ 7 class Train implements Serializable { 8 9 /** 10 * 11 */ 12 private static final long serialVersionUID = 1L; 13 14 public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser"; 15 // 训练集的位置 16 private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample"; 17 18 // 类别序号对应的实际名称 19 private Map<String, String> classMap = new HashMap<String, String>(); 20 21 // 类别对应的txt文本数 22 private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>(); 23 24 // 所有文本数 25 private AtomicInteger actCount = new AtomicInteger(0); 26 27 28 29 // 每个类别对应的词典和频数 30 private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>(); 31 32 // 分词器 33 private transient Participle participle; 34 35 private static Train trainInstance = new Train(); 36 37 public static Train getInstance() { 38 trainInstance = new Train(); 39 40 // 读取序列化在硬盘的本类对象 41 FileInputStream fis; 42 try { 43 File f = new File(SERIALIZABLE_PATH); 44 if (f.length() != 0) { 45 fis = new FileInputStream(SERIALIZABLE_PATH); 46 ObjectInputStream oos = new ObjectInputStream(fis); 47 trainInstance = (Train) oos.readObject(); 48 trainInstance.participle = new IkParticiple(); 49 } else { 50 trainInstance = new Train(); 51 } 52 } catch (Exception e) { 53 e.printStackTrace(); 54 } 55 56 return trainInstance; 57 } 58 59 private Train() { 60 this.participle = new IkParticiple(); 61 } 62 63 public String readtxt(String path) { 64 BufferedReader br = null; 65 StringBuilder str = null; 66 try { 67 br = new BufferedReader(new FileReader(path)); 68 69 str = new StringBuilder(); 70 71 String r = br.readLine(); 72 73 while (r != null) { 74 str.append(r); 75 r = br.readLine(); 76 77 } 78 79 return str.toString(); 80 } catch (IOException ex) { 81 ex.printStackTrace(); 82 } finally { 83 if (br != null) { 84 try { 85 br.close(); 86 } catch (IOException e) { 87 e.printStackTrace(); 88 } 89 } 90 str = null; 91 br = null; 92 } 93 94 return ""; 95 } 96 97 /** 98 * 训练数据 99 */ 100 public void realTrain() { 101 // 初始化 102 classMap = new HashMap<String, String>(); 103 classP = new HashMap<String, Integer>(); 104 actCount.set(0); 105 classWordMap = new HashMap<String, Map<String, Double>>(); 106 107 // classMap.put("C000007", "汽车"); 108 classMap.put("C000008", "财经"); 109 classMap.put("C000010", "IT"); 110 classMap.put("C000013", "健康"); 111 classMap.put("C000014", "体育"); 112 classMap.put("C000016", "旅游"); 113 classMap.put("C000020", "教育"); 114 classMap.put("C000022", "招聘"); 115 classMap.put("C000023", "文化"); 116 classMap.put("C000024", "军事"); 117 118 // 计算各个类别的样本数 119 Set<String> keySet = classMap.keySet(); 120 121 // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df 122 final Set<String> allWords = new HashSet<String>(); 123 124 // 存放每个类别的文件词汇内容 125 final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>(); 126 127 for (String classKey : keySet) { 128 129 Participle participle = new IkParticiple(); 130 Map<String, Double> wordMap = new HashMap<String, Double>(); 131 File f = new File(trainPath + File.separator + classKey); 132 File[] files = f.listFiles(new FileFilter() { 133 134 @Override 135 public boolean accept(File pathname) { 136 if (pathname.getName().endsWith(".txt")) { 137 return true; 138 } 139 return false; 140 } 141 142 }); 143 144 // 存储每个类别的文件词汇向量 145 List<String[]> fileContent = new ArrayList<String[]>(); 146 if (files != null) { 147 for (File txt : files) { 148 String content = readtxt(txt.getAbsolutePath()); 149 // 分词 150 String[] word_arr = participle.participle(content, false); 151 fileContent.add(word_arr); 152 // 统计每个词出现的个数 153 for (String word : word_arr) { 154 if (wordMap.containsKey(word)) { 155 Double wordCount = wordMap.get(word); 156 wordMap.put(word, wordCount + 1); 157 } else { 158 wordMap.put(word, 1.0); 159 } 160 161 } 162 } 163 } 164 165 // 每个类别对应的词典和频数 166 classWordMap.put(classKey, wordMap); 167 168 // 每个类别的文章数目 169 classP.put(classKey, files.length); 170 actCount.addAndGet(files.length); 171 classContentMap.put(classKey, fileContent); 172 173 } 174 175 176 177 178 179 // 把训练好的训练器对象序列化到本地 (空间换时间) 180 FileOutputStream fos; 181 try { 182 fos = new FileOutputStream(SERIALIZABLE_PATH); 183 ObjectOutputStream oos = new ObjectOutputStream(fos); 184 oos.writeObject(this); 185 } catch (Exception e) { 186 e.printStackTrace(); 187 } 188 189 } 190 191 /** 192 * 分类 193 * 194 * @param text 195 * <a href="http://my.oschina.net/u/556800" target="_blank" rel="nofollow">@return</a> 返回各个类别的概率大小 196 */ 197 public Map<String, Double> classify(String text) { 198 // 分词,并且去重 199 String[] text_words = participle.participle(text, false); 200 201 Map<String, Double> frequencyOfType = new HashMap<String, Double>(); 202 Set<String> keySet = classMap.keySet(); 203 for (String classKey : keySet) { 204 double typeOfThis = 1.0; 205 Map<String, Double> wordMap = classWordMap.get(classKey); 206 for (String word : text_words) { 207 Double wordCount = wordMap.get(word); 208 int articleCount = classP.get(classKey); 209 210 /* 211 * Double wordidf = idfMap.get(word); if(wordidf==null){ 212 * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); } 213 */ 214 215 // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算 216 double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1)) 217 : (wordCount / articleCount); 218 219 // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。 220 // 当double无限小的时候会归为0,为了避免 *10 221 222 typeOfThis = typeOfThis * term_frequency * 10; 223 typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE 224 : typeOfThis); 225 // System.out.println(typeOfThis+" : "+term_frequency+" : 226 // "+actCount); 227 } 228 229 typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis); 230 231 // 此类别文章出现的概率 232 double classOfAll = classP.get(classKey) / actCount.doubleValue(); 233 234 // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果 235 frequencyOfType.put(classKey, typeOfThis * classOfAll); 236 } 237 238 return frequencyOfType; 239 } 240 241 public void pringAll() { 242 Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap 243 .entrySet(); 244 for (Entry<String, Map<String, Double>> ent : classWordEntry) { 245 System.out.println("类别: " + ent.getKey()); 246 Map<String, Double> wordMap = ent.getValue(); 247 Set<Entry<String, Double>> wordMapSet = wordMap.entrySet(); 248 for (Entry<String, Double> wordEnt : wordMapSet) { 249 System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue()); 250 } 251 } 252 } 253 254 public Map<String, String> getClassMap() { 255 return classMap; 256 } 257 258 public void setClassMap(Map<String, String> classMap) { 259 this.classMap = classMap; 260 } 261 262 }
在试验过程中,发觉某篇文章的分类不太准,某篇IT文章分到招聘类别下了,在仔细对比了训练数据后,发觉这是由于招聘类别每篇文章下面都带有“搜狗”的标志,而待分类的这篇IT文章里面充斥这搜狗这类词汇,结果招聘类下的概率比较大。由此想到,在除了做常规的贝叶斯计算时,需要把不同文本中出现次数多的词汇权重降低甚至删除(好比关键词搜索中的tf-idf),通俗点讲就是,在所有训练文本中某词汇(如的,地,得)出现的次数越多,这个词越不重要,比如IT文章中“软件”和“应用”这两个词汇,“应用”应该是很多文章类别下都有的,反而不太重要,但是“软件”这个词汇大多只出现在IT文章里,出现在大量文章的概率并不大。 我这里原本打算计算每个词的idf,然后给定一个阀值来判断是否需要纳入计算,但是由于词汇太多,计算量较大(等待结果时间较长),所以暂时注释掉了。