代码改变世界

全文检索、数据挖掘、推荐引擎系列6---基于KMean的文本自动算法

2011-08-24 16:28  java ee spring  阅读(264)  评论(0编辑  收藏  举报

对一系列文章进行自动聚类可以做为基于内容的推荐引擎的基础,如果要实现文本的自动聚类,首先按照本系列5中所介绍的,对文章进行分词,然后计算得出文章的术语向量表示,即求文章中每个不同的单词以其所对应的TF*IDF,具体计算方法如5中所示。目前文本自动聚类算法中,用得最多是KMean算法,本文中就介绍KMean算法的应用。当然,KMean算法可以通过调用Mahout或WEKA这两个开源的机器学习算法库来实现,但是在这类算法中需要准备比较复杂的输入文件,预处理过程比较复杂,还有一点,我们可能在实际应用中要对KMean算法进行调整,这样自己编写KMean算法重加有助于我们对文本聚类算法的理解。

我们首先定义术语向量类,用来表示每篇文章的术语向量,还包括文档编号和类别编号,具体代码如下所示:

public class SepaTermVector {
 public SepaTermVector() {
  termVector = new Vector<TermInfo>();
 }
 public Vector<TermInfo> getTermVector() {
  return termVector;
 }
 public void setTermVector(Vector<TermInfo> termVector) {
  this.termVector = termVector;
 }
 public int getDocId() {
  return docId;
 }
 public void setDocId(int docId) {
  this.docId = docId;
 }
 public int getClusterId() {
  return clusterId;
 }
 public void setClusterId(int clusterId) {
  this.clusterId = clusterId;
 }
 
 /**
  * 在使用文章的术语向量时,我们不希望在自动聚类过程中破坏文章的术语向量,所以需要完体复
  * 制一份术语向量给自动聚类算法
  */
 @Override
 public SepaTermVector clone() {
  SepaTermVector obj = new SepaTermVector();
  obj.setDocId(docId);
  obj.setClusterId(clusterId);
  Vector<TermInfo> vt = new Vector<TermInfo>();
  for (TermInfo item : termVector) {
   vt.add(item);
  }
  obj.setTermVector(vt);
  return obj;
 }
 private Vector<TermInfo> termVector = null;
 private int docId = -1; // 所属的文章编号
 private int clusterId = -1; // 所属的聚类编号
}

然后我们定义文本聚类的类,在该类中保存聚类编号,聚类的中心(本身是该聚类中所有文章术语向量的一个并集)和聚类中所包含的术语向量(每个术语向量代表一篇文章)。代码如下所示:

public class TextClusterInfo {
 public TextClusterInfo(int clusterId) {
  this.clusterId = clusterId;
  items = new Vector<SepaTermVector>(); // 考虑线程安全性而牺牲部分性能
 }
 
 public void addItem(SepaTermVector item) {
  items.add(item);
 }
 
 public void clearItems() {
  items.clear();
 }
 
 /**
  * 计算本类型的中心点
  */
 public void computeCenter() {
  if (items.size() <= 0) {
   return ;
  }
  for (SepaTermVector item : items) {
   if (null == center) {
    center = item;
   } else {
    center = DocTermVector.calCenterTermVector(item, center);
   }
  }
 }
 
 public int getClusterId() {
  return clusterId;
 }
 public void setClusterId(int clusterId) {
  this.clusterId = clusterId;
 }
 public SepaTermVector getCenter() {
  return center;
 }
 public void setCenter(SepaTermVector center) {
  this.center = center;
 }
 public List<SepaTermVector> getItems() {
  return items;
 }
 public void setItems(List<SepaTermVector> items) {
  this.items = items;
 }
 private int clusterId = 0;
 private SepaTermVector center = null;
 private List<SepaTermVector> items = null;
}

接下来就是KMean自动聚类算法的工具类了,这里需要注意的是标准KMean自动聚类算法中,只需要指定初始的聚类数,然后由算法自动随机选择聚类中心,然后进行迭代,最终找出自动聚类结果。为了降低算法计算强度,我们在实际中不但给出了聚类数量,还给出了每个聚类的中心术语向量,即在大量文本中,找出每个聚类中的一篇代表性文章,作为参数传给自动聚类算法,在我们的实验数据中,可以很快达到收敛的效果,并且准确性很高。

KMean算分为下列几步:

  1. 根据所给出的聚类中心初始化聚类
  2. 清空每个聚类中属于该聚类的术语向量列表
  3. 针对每篇文章的术语向量,求出与其最近的聚类,将该术语向量加入到该聚类,如果上次循环中求出的聚类和本次不同,则表明还需继续运行聚类算法
  4. 计算加入新术语向量后的聚类新的中心
  5. 判断是否还需要运行自动聚类算法,若需要则回到2

衡量术语向量与聚类的相似度采用点积形式,就是术语向量与聚类中心所代表的术语向量相同单词权值之和,值越大代表二者越相像,具体代码如下所示:


 public static double getDotProdTvs(SepaTermVector stv1, SepaTermVector stv2) {
  double dotProd = 0.0;
  Hashtable<String, Double> dict = new Hashtable<String, Double>();
  for (TermInfo info : stv2.getTermVector()) {
   dict.put(info.getTermStr(), info.getWeight());
  }
  for (TermInfo item : stv1.getTermVector()) {
   if (dict.get(item.getTermStr())!= null) {
    dotProd += item.getWeight() * dict.get(item.getTermStr()).doubleValue();
   }
  }
  return dotProd;
 }

下面KMean算法实现类的代码:

public class TextKMeanCluster {
 /**
  * 在通常情况下,我们需要将待分类文本分出大致的类别来,即确定numClusters。在有些情况下,还可以指定某个类别中
  * 的一篇文章。这样可以避免算法不收敛时聚类的质量问题。
  * @param docTermVectors 需要进行聚类的术语向量
  * @param numClusters 聚类数量
  */
 public TextKMeanCluster(List<SepaTermVector> docTermVectors, int numClusters) {
  this.docTermVectors = docTermVectors;
  this.numClusters = numClusters;
 }
 
 /**
  * 对文章进行聚类
  * @param initCenters 聚类的中心点
  * @return 聚类结果
  */
 public List<TextClusterInfo> cluster(List<SepaTermVector> initCenters) {
  if (docTermVectors.size() <= 0) {
   return null;
  }
  initClusters(initCenters);
  boolean hadReassign = true;
  int runTimes = 0;
  while ((runTimes<=MAX_KMEAN_RUNTIMES) && (hadReassign)) {
   System.out.println("runTimes=" + runTimes + "!");
   clearClusterItems();
   hadReassign = reassignClusters();
   computeClusterCenters();
   runTimes++;
  }
  return clusters;
 }
 
 /**
  * 本算法中采用给定聚类中心的方式,但是在标准KMean算法中是随机选择聚类中心的,随机选择收敛较慢。
  */
 public void initClusters(List<SepaTermVector> initCenters) {
  clusters = new Vector<TextClusterInfo>();
  TextClusterInfo cluster = null;
  int i = 0;
  for (SepaTermVector stv : initCenters) {
   cluster = new TextClusterInfo(i++);
   cluster.setCenter(stv);
   clusters.add(cluster);
  }
 }
 
 /**
  * 求出所有文章术语向量的新聚类,如果与上次求出的聚类不同,则表明需要继续运行
  * @return 真时代表需要继续运行自动聚类算法
  */
 public boolean reassignClusters() {
  int numChanges = 0;
  TextClusterInfo newCluster = null;
  for (SepaTermVector termVector : docTermVectors) {
   newCluster = getClosetCluster(termVector);
   if ((termVector.getClusterId()<0) || termVector.getClusterId() != newCluster.getClusterId()) {
    numChanges++;
    termVector.setClusterId(newCluster.getClusterId());
   }
   newCluster.addItem(termVector);
   //System.out.println("reassignCluster:cid=" + newCluster.getClusterId() + ":size=" +
     //newCluster.getItems().size());
  }
  return (numChanges>0);
 }
 
 /**
  * 求出加入新述语向量后聚类的新中心
  */
 public void computeClusterCenters() {
  for (TextClusterInfo cluster : clusters) {
   cluster.computeCenter();
  }
 }
 
 /**
  * 清除该聚类的术语向量列表
  */
 public void clearClusterItems() {
  for (TextClusterInfo cluster : clusters) {
   cluster.clearItems();
  }
 }
 
 /**
  * 在标准KMean算法中随机抽取聚类中心的算法,在本类中该方法暂时未使用
  * @param usedIndex
  * @return
  */
 private SepaTermVector getTermVectorAtRandom(Hashtable<Integer, Integer> usedIndex) {
  boolean found = false;
  int index = -1;
  while (!found) {
   index = (int)Math.floor(Math.random() * docTermVectors.size());
   while (usedIndex.get(index) != null) {
    index = (int)Math.floor(Math.random() * docTermVectors.size());
   }
   usedIndex.put(index, index);
   return docTermVectors.get(index).clone(); // 重新复制一份,不破坏原来的拷贝
  }
  return null;
 }
 
 /**
  * 对术语向量和所有聚类中心所代表的术语向量做点积,取值最大的聚类为该文档的聚类
  * @param termVector 术语向量
  * @return 与该术语向量最接近的聚类
  */
 private TextClusterInfo getClosetCluster(SepaTermVector termVector) {
  TextClusterInfo closetCluster = null;
  double dotProd = -1.0;
  double maxDotProd = -2.0;
  double dist = -1.0;
  double smallestDist = Double.MAX_VALUE;
  for (TextClusterInfo cluster : clusters) {
   //dist = DocTermVector.calTermVectorDist(cluster.getCenter(), termVector);
   dotProd = DocTermVector.getDotProdTvs(cluster.getCenter(), termVector);
   //System.out.println("getClosetCluster:dotProd=" + dotProd + "[" + maxDotProd + "] docId="
     //+ termVector.getDocId() + "!");
   //if (dist < smallestDist) {
   if (dotProd > maxDotProd) {
    //smallestDist = dist;
    maxDotProd = dotProd;
    closetCluster = cluster;
   }
  }
  return closetCluster;
 }
 
 public final static int MAX_KMEAN_RUNTIMES = 1000;
 
 private List<SepaTermVector> docTermVectors = null; // 所有文章的术语向量
 private List<SepaTermVector> centers = null;
 private List<TextClusterInfo> clusters = null; // 所有聚类
 private int numClusters = 0;
}

具体的调用方法如下所示:

DocTermVector.init();
  // 技术类
  int doc1Id = FteEngine.genTermVector(-1, "Java语言编程技术详解", "", "", "", "");
  int doc2Id = FteEngine.genTermVector(-1, "C++语言编程指南", "", "", "", "");
  int doc4Id = FteEngine.genTermVector(-1, "Python程序设计教程", "", "", "", "");
  // 同性恋
  int doc3Id = FteEngine.genTermVector(-1, "同性恋网站变身电子商务网站", "", "", "", "");
  int doc5Id = FteEngine.genTermVector(-1, "同性恋网站大全", "", "", "", "");
  int doc6Id = FteEngine.genTermVector(-1, "男同性恋特点", "同性恋", "", "", "");
  // 天使投资
  int doc7Id = FteEngine.genTermVector(-1, "天使投资社交网络", "", "", "", "");
  int doc8Id = FteEngine.genTermVector(-1, "天使投资发展概况", "", "", "", "");
  int doc9Id = FteEngine.genTermVector(-1, "著名天使投资人和天使投资机构", "", "", "", "");
  // 环境保护
  int doc10Id = FteEngine.genTermVector(-1, "环境保护技术分析", "", "", "", "");
  int doc11Id = FteEngine.genTermVector(-1, "环境保护与碳关税分析", "", "", "", "");
  int doc12Id = FteEngine.genTermVector(-1, "环境保护与我国经济发展趋势", "", "", "", "");
  
  FteEngine.genTermVector(-1, "VB编程指南", "", "", "", "");
  FteEngine.genTermVector(-1, "天使投资社区天使街正式上线运行", "", "", "", "");
  FteEngine.genTermVector(-1, "年度编程语言评选活动", "", "", "", "");

List<SepaTermVector> centers = new Vector<SepaTermVector>();
  centers.add(DocTermVector.getDocTermVector(0));
  centers.add(DocTermVector.getDocTermVector(3));
  centers.add(DocTermVector.getDocTermVector(6));
  centers.add(DocTermVector.getDocTermVector(9));
  TextKMeanCluster tkmc = new TextKMeanCluster(DocTermVector.getDocTermVectors(), 4);
  List<TextClusterInfo> rst = tkmc.cluster(centers);
  String lineStr = null;
  for (TextClusterInfo info : rst) {
   lineStr = "" + info.getClusterId() + "(" + info.getItems().size() + "):";
   for (SepaTermVector tvItem : info.getItems()) {
    lineStr += " " + tvItem.getDocId();
   }
   lineStr += "^_^";
   System.out.println(lineStr);
  }

运行的结果为:

0(5): 0 1 2 12 14^_^
1(3): 3 4 5^_^
2(4): 6 7 8 13^_^
3(3): 9 10 11^_^

由上面的结果来看,实现了基本正确的聚类。