计算信息增益,选择特征词
上一节我们已经对训练集建立了word-doc矩阵,每读取矩阵的一行就可以计算出term对应的IG值。最后把结果写入文件。
信息增益的计算公式参见我的另一篇博客信息论。
代码如下:
View Code
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.PrintWriter;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
public class CalIG {
public void calIG(File matrixFile,File IGFile) {
if (!matrixFile.exists()) {
System.out.println("Matrix文件不存在.程序退出.");
System.exit(2);
}
int category_num = 9; //一共有9大分类
int doc_num=7196;//总共有7196篇文档,也是word-doc矩阵的列数
int[] category_count={1070,440,513,816,750,756,1392,473,986}; //每个分类包含的文档数
double HC=getEntropy(category_count);
try {
FileReader fr = new FileReader(matrixFile);
BufferedReader br = new BufferedReader(fr);
PrintWriter pw=new PrintWriter(new FileOutputStream(IGFile));
String line = null;
while ((line = br.readLine()) != null) {
String[] content = line.split("\\s+");
String term = content[0];
ArrayList<Short> al = new ArrayList<Short>(doc_num);
for (int i = 0; i < doc_num; i++) {
short count = Short.parseShort(content[i + 1]);
al.add(count);
}
int term_count = 0; // 出现term的文档数量
int[] term_class_count = new int[category_num];// 每个类别中出现term的文档数量
int[] term_b_class_count = new int[category_num];// 每个类别中不出现term的文档数量
int index=0;
for (int i = 0; i < category_num; i++) {
for (int j = 0; j < category_count[i]; j++) {
if (al.get(index) > 0) {
term_class_count[i]++;
}
index++;
}
term_b_class_count[i]= category_count[i]-term_class_count[i];
term_count += term_class_count[i];
}
double HCT=1.0*term_count/doc_num*getEntropy(term_class_count)+1.0*(doc_num-term_count)/doc_num*getEntropy(term_b_class_count);
double IG = HC - HCT;
pw.println(term+"\t"+String.valueOf(IG));
pw.flush();
}
br.close();
pw.close();
} catch (Exception e) {
e.printStackTrace();
}
}
public double getEntropy(int[] arr){
int sum=0;
double entropy=0.0;
for(int i=0;i<arr.length;i++){
sum+=arr[i];
entropy+=arr[i]*Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
}
entropy/=sum;
entropy-=Math.log(sum)/Math.log(2);
return 0-entropy;
}
public static void main(String[] args) throws Exception{
Date currentTime = new Date();
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
System.out.println("Begin Time: "+formatter.format(currentTime));
CalIG inst=new CalIG();
File in=new File("/home/orisun/matrix/part-r-00000");
File out=new File("/home/orisun/frequency1");
inst.calIG(in, out);
currentTime = new Date();
System.out.println("End Time: "+formatter.format(currentTime));
}
}
本来生成了大约32.6万个term-ig对,后来我去掉那些以数字、字母、符号开头的term,剩下117342个term。假如需要4000个特征项,那么我们就要按取出ig值最大的前4000个term(注意并不需要对117342条记录全部排序)。可以看我的另外一篇博客《寻找N个元素中的前K个最大者》,这里就不再重复介绍了,直接上代码:
#ifndef ITERM_H #define ITERM_H #include<string> using namespace std; class iterm{ private: string term; double ig; public: iterm(string term,double ig):term(term),ig(ig){} bool operator == (const iterm & i2){ return ig==i2.ig; } bool operator > (const iterm & i2){ return ig>i2.ig; } bool operator < (const iterm & i2){ return ig<i2.ig; } ostream& operator << (ostream& out){ out<<term<<"\t"<<ig; return out; } string getTerm(){ return term; } double getIG(){ return ig; } }; #endif
#include<iostream> #include<cstdlib> #include<ctime> #include<vector> #include<fstream> #include<sstream> #include"iterm.h" template<typename Comparable> void percolate(vector<Comparable> &vec,int index){ int i=index; int j=2*i+1; while(j<vec.size()){ if(j<vec.size()-1 && vec[j]>vec[j+1]) j++; if(vec[i]<vec[j]) break; else{ swap(vec[i],vec[j]); i=j; j=2*i+1; } } } template<typename Comparable> void buildHeap(vector<Comparable> &vec){ int len=vec.size(); for(int i=(len-1)/2;i>=0;i--) percolate(vec,i); } int main(){ clock_t t1=clock(); const int K=4000; vector<iterm> vec; string infn="frequency1"; string outfn="features2"; ifstream infile(infn.c_str(),ios::in); ofstream outfile(outfn.c_str(),ios::out); string line,word,sd; double d; int n=K; while(getline(infile,line)){ istringstream sstr(line); sstr>>word; sstr>>sd; d=atof(sd.c_str()); iterm inst(word,d); if(n>0){ vec.push_back(inst); n--; } else{ if(n==0){ buildHeap(vec); n=-1; } if(inst>vec[0]){ vec[0]=inst; percolate(vec,0); } } } infile.close(); for(int i=0;i<K;i++) outfile<<vec[i].getTerm()<<"\t"<<vec[i].getIG()<<endl; outfile.close(); clock_t t2=clock(); cout<<"Time:"<<(t2-t1)<<endl; return 0; }
这段时间一直Java,今天又体验了一把C++的极速,执行上面代码只用了0.25秒!
如果实际工作需要,可以再对这4000个term来一次内排序。
本文来自博客园,作者:高性能golang,转载请注明原文链接:https://www.cnblogs.com/zhangchaoyang/articles/2236475.html