原创:Kmeans算法实战+改进(java实现)

kmeans算法的流程:
 

  EM思想很伟大,在处理含有隐式变量的机器学习算法中很有用。聚类算法包括kmeans,高斯混合聚类,快速迭代聚类等等,都离不开EM思想。在了解kmeans算法之前,有必要详细了解一下EM思想。

  Kmeans算法属于无监督学习中的一种,相比于监督学习,能节省很多成本,省去了大量的标签标注。每个数据点都有自己的隐式的分类。聚类要做的是,从中选取出数个聚类中心,对数据集进行初始聚类。此后,通过更新聚类中心(把簇中心缓存起来),重新聚类,然后再更新簇中心,如果此簇中心与旧的簇中心的差值(2范数)<阈值,说明聚类趋于稳定,迭代结束。Kmeans算法,通过计算两个数据点的欧氏距离(2范数),来对数据点进行归类。这个算法和高斯混合聚类相比,要死板很多,而且,有一个最重要的弱点,就是聚类结果对初始化的簇中心比较敏感,而且容易陷入局部最优。因为,评价kmeans的损失函数属于非凸函数,不能取得全局最优解。稍后,在代码中,会有说明。如果想改进这个算法,可以考虑半聚类算法,与其他算法结合起来,削弱其弱点 。

      关于算法的研究,本人人为,应该从以下三方面着手:第一境界,明白原理,从理论上获得支撑;第二层面,深刻理解算法实现,能够根据高数进行推导,并且找出算法的优劣点;第三层面,能够证明算法的正确性,并且提出改进方案。Kmeans算法是基于EM思想,Kmeans算法的挑战在于如何提高聚类的准确性和稳定性。 在改进上主要朝着上述两个方向努力。改进的时候,首先要提出理论上的支持,在实施上,主要手段围绕着改进簇中心的选取方式以及挖掘出k值的隐式最优值。改进簇中心选取方式的目标就是提高准确率和稳定性,挖掘k值隐式最优解是为了提高聚类的颗粒度,追求最优效果。因为使用算法的人,不一定保证能真正深刻理解算法,并且对于训练数据的内部规律,也不一定清晰。而且Kmeans算法,人为地在外部设置k值,这种做法,本身就存在一定的不合理性。不像监督学习,训练数据的标签,可以按照人的想法进行划分,比如设置3类,或者4类。但是,自动聚类,机器并不能做到人这么智能化。所以,关于k值的设定,有必要改进一下,让机器在一定程度上,自动识别出最优解。这样,在外部调用算法时,当用户设置的k值<隐式最优解的时候,按照k值数目进行聚类,当用户设置的k值很大时,超出了k值的隐式最优解,算法内部应该能够自动调整k值为最优解。这就是方向,有了方向后,就可以沿着这个思路去思考,尝试,测试,直到成功。另外好的算法,从代码层面上看,大都是简单易于执行的,乍一看,就那么几个数据结构。但是能够提出想法,并且从理论上需求突破,这才是最难的。最好的事务,都是很朴实的,使用起来很简单,比如微软提出来的全排列最优算法。      

      下面,上传本人最近编写的Kmeans算法,这个算法中,有三个地方进行了改进:①增加了数据的归一化处理,以消除大的数据的影响;②增加了数据归类算法,使输出的数据同一类别的,连续存储,使输出结果更加人性化;③使簇中心的选取方式及个数约束更加合理化。追求的效果:一为准确,二为稳定,三消除簇中心的敏感性(实际上,关于这一点永远不能消除,只能最大限度地提升准确率)。

  首先,展示一下未改进前的算法:    

package com.txq.kmeans;

/**
 *
 * @param <b>data</b> <i>in double[length][dim]</i><br/>length个instance的坐标,第i(0~length-1)个instance为data[i]
 * @param <b>length</b> <i>in</i> instance个数
 * @param <b>dim</b> <i>in</i> instance维数
 * @param <b>labels</b> <i>out int[length]</i><br/>聚类后,instance所属的聚类标号(0~k-1)
 * @param <b>centers</b> <i>in out double[k][dim]</i><br/>k个聚类中心点的坐标,第i(0~k-1)个中心点为centers[i]
 * @author Yuanbo She
 *
 */
public class Kmeans_data {
 public double[][] data;//原始矩阵
 public int length;//矩阵长度
 public int dim;//特征维度
 public int[] labels;//数据所属类别的标签,即聚类中心的索引值
 public double[][] centers;//聚类中心矩阵
 public int[] centerCounts;//每个聚类中心的元素个数
 public double [][]originalCenters;//最初的聚类中心坐标点集
 public Kmeans_data(double[][] data, int length, int dim) {
  this.data = data;
  this.length = length;
  this.dim = dim;  
 }
}

然后,定义聚类所需的参数:

public class Kmeans_param {
 public static final int CENTER_ORDER = 0;
 public static final int CENTER_RANDOM = 1;
 public static final int MAX_ATTEMPTS = 4000;
 public static final double MIN_CRITERIA = 1.0;
 public static final double MIN_EuclideanDistance = 0.8;
 public double criteria = MIN_CRITERIA; //阈值
 public int attempts = MAX_ATTEMPTS; //尝试次数
 public int initCenterMethod = CENTER_RANDOM ; //初始化聚类中心点方式
 public boolean isDisplay = true; //是否直接显示结果

 public double min_euclideanDistance = MIN_EuclideanDistance;
}

还要定义聚类显示的结果:

/**
 *
 * 聚类显示的结果
 * @author TongXueQiang
 */
public class Kmeans_result {
    public int attempts; // 退出迭代时的尝试次数
    public double criteriaBreakCondition; // 退出迭代时的最大距离(小于阈值)
    public int k; // 聚类数
    public int perm[];//归类后连续存放的原始数据索引 
 public int start[];//每个类在原始数据中的起始位置
}

接下来,开始聚类:

package com.txq.kmeans;

import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Kmeans聚类算法
 * @author TongXueQiang
 * @date 2016/11/09
 */
public class Kmeans {
 private static DecimalFormat df = new DecimalFormat("#####.00");//对数据格式化处理

 public Kmeans_data data = null;

 public Kmeans(double [][]da){

  data = new Kmeans_data(da,da.length,da[0].length);

}
 /**
  * double[][] 元素全置
  *
  * @param matrix
  *            double[][]
  * @param highDim
  *            int
  * @param lowDim
  *            int <br/>
  *            double[highDim][lowDim]
  */
 private static void setDouble2Zero(double[][] matrix, int highDim, int lowDim) {
  for (int i = 0; i < highDim; i++) {
   for (int j = 0; j < lowDim; j++) {
    matrix[i][j] = 0;
   }
  }
 }

 /**
  * 拷贝源二维矩阵元素到目标二维矩阵。 foreach (dests[highDim][lowDim] =
  * sources[highDim][lowDim]);
  *
  * @param dests
  *            double[][]
  * @param sources
  *            double[][]
  * @param highDim
  *            int
  * @param lowDim
  *            int
  */
 private static void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) {
  for (int i = 0; i < highDim; i++) {
   for (int j = 0; j < lowDim; j++) {
    dests[i][j] = sources[i][j];
   }
  }
 }

 /**
  * 更新聚类中心坐标,实现思路为:先求簇中心的和,然后求取均值。
  *
  * @param k
  *            int 分类个数
  * @param data
  *            kmeans_data
  */
 private static void updateCenters(int k, Kmeans_data data) {
  double[][] centers = data.centers;  
  setDouble2Zero(centers, k, data.dim);//归零处理
  int[] labels = data.labels;
  int[] centerCounts = data.centerCounts;
  for (int i = 0; i < data.dim; i++) {
   for (int j = 0; j < data.length; j++) {
    centers[labels[j]][i] += data.data[j][i];
   }
  }
  for (int i = 0; i < k; i++) {
   for (int j = 0; j < data.dim; j++) {
    centers[i][j] = centers[i][j] / centerCounts[i];
    centers[i][j] = Double.valueOf(df.format(centers[i][j]));
   }
  }
 }

 /**
  * 计算两点欧氏距离
  *
  * @param pa
  *            double[]
  * @param pb
  *            double[]
  * @param dim
  *            int 维数
  * @return double 距离
  */
 public static double dist(double[] pa, double[] pb, int dim) {
  double rv = 0;
  for (int i = 0; i < dim; i++) {
   double temp = pa[i] - pb[i];
   temp = temp * temp;
   rv += temp;
  }
  return Math.sqrt(rv);
 }

 /**
  * 做Kmeans运算
  *
  * @param k
  *            int 聚类个数
  * @param data
  *            kmeans_data kmeans数据类
  * @param param
  *            kmeans_param kmeans参数类
  * @return kmeans_result kmeans运行信息类
  */
 public static Kmeans_result doKmeans(int k, Kmeans_param param) {
  //对数据进行规一化处理,以消除大的数据的影响
  normalize(data);
//  System.out.println("规格化处理后的数据:");
//  for (int i = 0;i < data.length;i++) {
//   for (int j = 0;j < data.dim;j++) {
//    System.out.print(data.data[i][j] + " ");
//   }
//   System.out.println();
//  }


  // 预处理
  double[][] centers = new double[k][data.dim]; // 聚类中心点集
  data.centers = centers;
  int[] centerCounts = new int[k]; // 各聚类的包含点个数
  data.centerCounts = centerCounts;
  Arrays.fill(centerCounts, 0);
  int[] labels = new int[data.length]; // 各个点所属聚类标号
  data.labels = labels;
  double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标
  
  // 初始化聚类中心(随机或者依序选择data内的k个不重复点)
  if (param.initCenterMethod == Kmeans_param.CENTER_RANDOM) { // 随机选取k个初始聚类中心
   Random rn = new Random();
   List<Integer> seeds = new LinkedList<Integer>();
   while (seeds.size() < k) {
    int randomInt = rn.nextInt(data.length);
    if (!seeds.contains(randomInt)) {
     seeds.add(randomInt);
    }
   }
   Collections.sort(seeds);
   for (int i = 0; i < k; i++) {
    int m = seeds.remove(0);
    for (int j = 0; j < data.dim; j++) {
     centers[i][j] = data.data[m][j];
    }
   }
  } else { // 选取前k个点位初始聚类中心
   for (int i = 0; i < k; i++) {
    for (int j = 0; j < data.dim; j++) {
     centers[i][j] = data.data[i][j];
    }
   }
  }
  //给最初的聚类中心赋值
  data.originalCenters = new double[k][data.dim];
  for (int i = 0; i < k; i++) {
   for (int j = 0; j < data.dim; j++) {
    data.originalCenters[i][j] = centers[i][j];
   }
  }
  
  // 第一轮迭代
  for (int i = 0; i < data.length; i++) {
   double minDist = dist(data.data[i], centers[0], data.dim);
   int label = 0;
   for (int j = 1; j < k; j++) {
    double tempDist = dist(data.data[i], centers[j], data.dim);
    if (tempDist < minDist) {
     minDist = tempDist;
    label = j;
    }
   }
   labels[i] = label;
   centerCounts[label]++;
  }
  updateCenters(k, data);//更新簇中心
  copyCenters(oldCenters, centers, k, data.dim);

  // 迭代预处理
  int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS;
  int attempts = 1;
  double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA;
  double criteriaBreakCondition = 0;
  boolean[] flags = new boolean[k]; // 标记哪些中心被修改过

  // 迭代
  iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值
   for (int i = 0; i < k; i++) { // 初始化中心点“是否被修改过”标记
    flags[i] = false;
   }
   for (int i = 0; i < data.length; i++) { // 遍历data内所有点
    double minDist = dist(data.data[i], centers[0], data.dim);
    int label = 0;
    for (int j = 1; j < k; j++) {
     double tempDist = dist(data.data[i], centers[j], data.dim);
     if (tempDist < minDist) {
      minDist = tempDist;
      label = j;
     }
    }
    if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新
     int oldLabel = labels[i];
     labels[i] = label;
     centerCounts[oldLabel]--;
     centerCounts[label]++;
     flags[oldLabel] = true;
     flags[label] = true;
    }
   }
   updateCenters(k, data);
   attempts++;

   // 计算被修改过的中心点最大修改量是否超过阈值
   double maxDist = 0;
   for (int i = 0; i < k; i++) {
    if (flags[i]) {
     double tempDist = dist(centers[i], oldCenters[i], data.dim);
     if (maxDist < tempDist) {
      maxDist = tempDist;
     }
     for (int j = 0; j < data.dim; j++) { // 更新oldCenter
      oldCenters[i][j] = centers[i][j];
      oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j]));
     }
    }
   }
   if (maxDist < criteria) {
    criteriaBreakCondition = maxDist;
    break iterate;
   }
  }

  // 输出信息,把属于同一类的数据连续存放  
  Kmeans_result rvInfo = new Kmeans_result();
  int perm[] = new int[data.length];
  rvInfo.perm = perm;
  int start[] = new int[k];
  rvInfo.start = start;
  group_class(perm,start,k,data);
  
  rvInfo.attempts = attempts;
  rvInfo.criteriaBreakCondition = criteriaBreakCondition;
  if (param.isDisplay) {
   System.out.println("最初的聚类中心:");
   for(int i = 0;i < data.originalCenters.length;i++){
    for(int j = 0;j < data.dim;j++){
     System.out.print(data.originalCenters[i][j]+" ");
    }
    System.out.print("\t类别:"+i+"\t"+"总数:"+centerCounts[i]);
    System.out.println();
   }
   System.out.println("\n聚类结果--------------------------->");

   int originalCount = 0;
   for (int i = 0;i < k;i++) {
    int index = data.labels[perm[start[i]]];//所属类别
    int count = data.centerCounts[index];//类别中个体数目
    originalCount += count;
    System.out.println("所属类别:" + index);
    for (int j = start[i];j < originalCount;j++) {
     for (double num:data.data[perm[j]]) {
      System.out.print(num+" ");
     }
     System.out.println();
    }
   }
  }
  return rvInfo;
 }
 /**
  * @author TongXueQiang
  * @param perm 连续存放归类后的原始数据的索引
  * @param start 每个类的起始索引位置
  * @param k 聚类中心个数
  * @param data 原始数据---二维矩阵
  */
 private static void group_class(int perm[],int start[],int k,Kmeans_data data){  
  start[0] = 0;
  for(int i = 1;i < k;i++){
   start[i] = start[i-1] + data.centerCounts[i-1];
  }
  
  for(int i = 0;i < data.length;i++){
   perm[start[data.labels[i]]++] = i;   
  }
  
  start[0] = 0;
  for(int i = 1;i < k;i++){
   start[i] = start[i-1] + data.centerCounts[i-1];
  }
 }
 /**
  * 规一化处理
  * @param data
  * @author TongXueQiang
  */
 private static void normalize(Kmeans_data data){
  //1.首先计算各个列的最大和最小值,存入map中
  Map<Integer,Double[]> minAndMax = new HashMap<Integer,Double[]>();
  for(int i = 0;i < data.dim;i++){
   Double []nums = new Double[2];
   double max = data.data[0][i];
   double min = data.data[data.length-1][i];
   for(int j = 0;j < data.length;j++){
    if(data.data[j][i] > max){
     max = data.data[j][i];
    }
    if(data.data[j][i] < min){
     min = data.data[j][i];
    }    
   }
   nums[0] = min; nums[1] = max;
   minAndMax.put(i,nums);
  }
  //2.更新矩阵的值
  for(int i = 0;i < data.length;i++){
   for(int j = 0;j < data.dim;j++){
    double minValue = minAndMax.get(j)[0];
    double maxValue = minAndMax.get(j)[1];
    data.data[i][j] = (data.data[i][j] - minValue)/(maxValue - minValue);
    data.data[i][j] = Double.valueOf(df.format(data.data[i][j]));
   }
  }
 }
}

测试类:

package com.txq.kmeans.test;

import org.junit.Test;
import com.txq.kmeans.Kmeans;
import com.txq.kmeans.Kmeans_data;
import com.txq.kmeans.Kmeans_param;

public class KmeansTest {
 
 @Test
 public void test() {
  double [][]da = new double[6][];  
  da[0] = new double[]{1,5,132};
  da[1] = new double[]{3,7,12};
  da[2] = new double[]{67,23,45};
  da[3] = new double[]{34,5,13};
  da[4] = new double[]{12,7,21};
  da[5] = new double[]{26,23,54};
  Kmeans kmeans = new Kmeans(da);
  kmeans.doKmeans(3);
 }
}

 输出结果,注意观察:  
 
最初的聚类中心:
0.0 0.0 1.0  类别:0 总数:1
0.03 0.11 0.0  类别:1 总数:3
0.5 0.0 0.01  类别:2 总数:2

聚类结果--------------------------->
所属类别:0
0.0 0.0 1.0
所属类别:1
0.03 0.11 0.0
0.5 0.0 0.01
0.17 0.11 0.07
所属类别:2
1.0 1.0 0.28
0.38 1.0 0.35

观察这个结果,发现,随机初始化的三个簇中心,其中有两个的欧氏距离非常接近,属于同一类的。这种情况,聚类结果,就会有偏差,很不合理。

最初的聚类中心:
0.03 0.11 0.0  类别:0 总数:4
1.0 1.0 0.28  类别:1 总数:1
0.38 1.0 0.35  类别:2 总数:1

聚类结果--------------------------->
所属类别:0
0.0 0.0 1.0
0.03 0.11 0.0
0.5 0.0 0.01
0.17 0.11 0.07
所属类别:1
1.0 1.0 0.28
所属类别:2
0.38 1.0 0.35

最初的聚类中心:
1.0 1.0 0.28  类别:0 总数:1
0.5 0.0 0.01  类别:1 总数:4
0.38 1.0 0.35  类别:2 总数:1

聚类结果--------------------------->
所属类别:0
1.0 1.0 0.28
所属类别:1
0.0 0.0 1.0
0.03 0.11 0.0
0.5 0.0 0.01
0.17 0.11 0.07
所属类别:2
0.38 1.0 0.35

      上述算法中,对初始簇中心严重依赖。具体来说,采用随机初始化的方式,聚类结果很不稳定,而且严重影响准确率。选取簇中心的原则是,每两个中心之间的欧氏距离应该尽量大。而且,k的数目应该有隐式的约束,太少或者太大都不合理。所以,应该同时约束上述两个因素。最好的办法是,用概率密度分析,比如高斯分布。把所有的训练数据中每两个数据的欧氏距离看作是基本变量,遵循Gaussian分布。原始数据全部归一化处理后,欧氏距离取值范围应该在:(0,√n)之间,借鉴二元分类的思想,取均值,如果大于均值的话,属于同一类的概率比较大,反之较小。其中,n为dimension.按照此种方法处理的话,会隐式地约束K的个数,使之更加合理。比如,训练数据中,每两个数据的欧氏距离>mean的中心点可能只有3个,如果你在外部调用算法时,人为地设定为4个或者5个的话,应该自动把K值降低为合理值。这样聚类的结果,一定是最优的。所以,要想达到最优效果,外部传递k值的时候,可以尽量地大,或者不设置,在不断测试的过程中,发现改为顺寻扫描效果更佳。但是,会增加时间复杂度。关于算法的精确度和时间复杂度,往往不能两全。转化为工程应用时,可以在牺牲一定精度的前提下,换取时间复杂度的提升。比如,在计算训练数据的欧氏距离的均值的时候,可以只考虑矩阵中的第一个数据与其他所有数据的欧氏距离,计算最大值和最小值,然后折中处理,不能计算所有的组合情况的E(期望)。准确度很高,而且把时间复杂度降低了一个数量级,原来O(n^2)变为O(n)。代码如下:

package com.txq.kmeans;

import java.util.Map;

/** 
 * 聚类模型
 * @author TongXueQiang
 * @date 2017/09/09
 */
public class ClusterModel {
	public double originalCenters[][];
	public int centerCounts[];
    public int attempts; //最大迭代次数
    public double criteriaBreakCondition; // 迭代结束时的最小阈值
    public int[] labels;
    public int k;
    public int perm[];//连续存放的样本	
	public int start[];//每个中心开始的位置
	public Map<String,Integer> identifier;	
	public Kmeans_data data;
	public Map<Integer, String> iden0;	
	
	public void centers(){
		System.out.println("聚类中心:");
		for (int i = 0; i < originalCenters.length; i++) {
			for (int j = 0; j < originalCenters[0].length; j++) {
				System.out.print(originalCenters[i][j] + " ");
			}
			System.out.print("\t"+"第" + (i+1)+"类:" + "\t" + "样本个数:" + centerCounts[i]);
			System.out.println();
		}
	}
	
	public int predict(String iden){		
		int label = labels[identifier.get(iden)];
		return label;
	}	
	
	public void outputAllResult(){
		System.out.println("\n最后聚类结果--------------------------->");

		int originalCount = 0;
		for (int i = 0; i < k; i++) {
			int index = labels[perm[start[i]]];
			int counts = centerCounts[index];
			originalCount += counts;
			System.out.println("第"+(index+1)+"类成员:");
			for (int j = start[i]; j < originalCount; j++) {
				for (double num : data.data[perm[j]]) {
					System.out.print(num + " ");
				}
				System.out.print(":"+iden0.get(perm[j]));
				System.out.println();				
			}
		}
	}
}

 

package com.txq.kmeans;

/**
 * 
 * @author TongXueQiang
 * @param data 原始矩阵
 * @param labels 样本所属类别
 * @param centers 聚类中心
 * @date 2017/09/09
 */
public class Kmeans_data {
	public double[][] data;
	public int length;
	public int dim;	
	public double[][] centers;	
	
	public Kmeans_data(double[][] data, int length, int dim) {
		this.data = data;
		this.length = length;
		this.dim = dim;		
	}
}

 

package com.txq.kmeans;

/**
 * 控制k_means迭代的参数
 * @author TongXueQiang
 * @date 2017/09/09
 */
public class Kmeans_param {
	public static final int K = 240;//系统默认的最大聚类中心个数	
	public static final int MAX_ATTEMPTS = 4000;//最大迭代次数
	public static final double MIN_CRITERIA = 0.1;
	public static final double MIN_EuclideanDistance = 0.8;
	public double criteria = MIN_CRITERIA; //最小阈值
	public int attempts = MAX_ATTEMPTS; 	
	public boolean isDisplay = true; 
	public double min_euclideanDistance = MIN_EuclideanDistance;
}

 

package com.txq.kmeans;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * kMeans聚类算法 
 * @author TongXueQiang
 * @date 2017/09/09
 */
public class Kmeans {
	private DecimalFormat df = new DecimalFormat("#####.00");
	public Kmeans_data data = null;
	// feature,样本名称和索引映射
	private Map<String, Integer> identifier = new HashMap<String, Integer>();
	private Map<Integer, String> iden0 = new HashMap<Integer, String>();
	private ClusterModel model = new ClusterModel();

	/**
	 * 文件到矩阵的映射 
	 * @param path
	 * @return
	 * @throws Exception
	 */
	public double[][] fileToMatrix(String path) throws Exception {
		List<String> contents = new ArrayList<String>();
		model.identifier = identifier;
		model.iden0 = iden0;
		
		FileInputStream file = null;
		InputStreamReader inputFileReader = null;
		BufferedReader reader = null;
		String str = null;
		int rows = 0;
		int dim = 0;
		
        try {
            file = new FileInputStream(path);
            inputFileReader = new InputStreamReader(file, "utf-8");
            reader = new BufferedReader(inputFileReader);
            // 一次读入一行,直到读入null为文件结束
            while ((str = reader.readLine()) != null) {
            	contents.add(str);
    			++rows;
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }

		String[] strs = contents.get(0).split(":");
		dim = strs[0].split(" ").length;

		double[][] da = new double[rows][dim];

		for (int j = 0; j < contents.size(); j++) {
			strs = contents.get(j).split(":");			
			identifier.put(strs[1], j);
			iden0.put(j, strs[1]);
			String[] feature = strs[0].split(" ");
			for (int i = 0; i < dim; i++) {				
				da[j][i] = Double.parseDouble(feature[i]);
			}
		}		
		return da;
	}

	/**
	 * 清零操作
	 * @param matrix
	 * @param highDim
	 * @param lowDim
	 */
	private void setDouble2Zero(double[][] matrix, int highDim, int lowDim) {
		for (int i = 0; i < highDim; i++) {
			for (int j = 0; j < lowDim; j++) {
				matrix[i][j] = 0;
			}
		}
	}

	/**
	 * 聚类中心拷贝 
	 * @param dests
	 * @param sources
	 * @param highDim
	 * @param lowDim
	 */
	private void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) {
		for (int i = 0; i < highDim; i++) {
			for (int j = 0; j < lowDim; j++) {
				dests[i][j] = sources[i][j];
			}
		}
	}

	/**
	 * 更新聚类中心
	 * @param k
	 * @param data
	 */
	private void updateCenters(int k, Kmeans_data data) {
		double[][] centers = data.centers;
		setDouble2Zero(centers, k, data.dim);
		int[] labels = model.labels;
		int[] centerCounts = model.centerCounts;
		for (int i = 0; i < data.dim; i++) {
			for (int j = 0; j < data.length; j++) {
				centers[labels[j]][i] += data.data[j][i];
			}
		}
		for (int i = 0; i < k; i++) {
			for (int j = 0; j < data.dim; j++) {
				centers[i][j] = centers[i][j] / centerCounts[i];				
			}
		}
	}

	/**
	 * 计算欧氏距离 
	 * @param pa
	 * @param pb
	 * @param dim
	 * @return
	 */
	public double dist(double[] pa, double[] pb, int dim) {
		double rv = 0;
		for (int i = 0; i < dim; i++) {
			double temp = pa[i] - pb[i];
			temp = temp * temp;
			rv += temp;
		}
		return Math.sqrt(rv);
	}

	/**
	 * 样本训练,需要人为设定k值(聚类中心数目)
	 * @param k
	 * @param data
	 * @return
	 * @throws Exception
	 */
	public ClusterModel train(String path, int k) throws Exception {
		double[][] matrix = fileToMatrix(path);
		data = new Kmeans_data(matrix, matrix.length, matrix[0].length);
		return train(k, new Kmeans_param());
	}

	/**
	 * 样本训练(系统默认最优聚类中心数目)
	 * @param data
	 * @return
	 * @throws Exception
	 */
	public ClusterModel train(String path) throws Exception {
		double[][] matrix = fileToMatrix(path);
		data = new Kmeans_data(matrix, matrix.length, matrix[0].length);
		return train(new Kmeans_param());
	}
	
	private ClusterModel train(Kmeans_param param) {
		int k = Kmeans_param.K;
		// 首先进行数据归一化处理
		normalize(data);
		// 计算第一个样本和后面的所有样本的欧氏距离,存入list中然后计算均值,作为聚类中心选取的依据
		List<Double> dists = new ArrayList<Double>();
		for (int i = 1; i < data.length; i++) {
			dists.add(dist(data.data[0], data.data[i], data.dim));
		}
		param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists) + Collections.min(dists)) / 2));
		double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance
				: Kmeans_param.MIN_EuclideanDistance;
		
		int centerIndexes[] = new int[k];// 收集聚类中心索引的数组
		int countCenter = 0;// 动态表示中心的数目
		int count = 0;// 计数器
		centerIndexes[0] = 0;
		countCenter++;
		for (int i = 1; i < data.length; i++) {
			for (int j = 0; j < countCenter; j++) {
				if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) {
					count++;
				}
			}
			if (count == countCenter) {
				centerIndexes[countCenter++] = i;
			}
			count = 0;
		}
		
		double[][] centers = new double[countCenter][data.dim]; // 聚类中心
		data.centers = centers;
		int[] centerCounts = new int[countCenter]; // 聚类中心的样本个数
		model.centerCounts = centerCounts;
		Arrays.fill(centerCounts, 0);
		int[] labels = new int[data.length]; // 样本的类别
		model.labels = labels;
		double[][] oldCenters = new double[countCenter][data.dim]; // 存储旧的聚类中心

		// 给聚类中心赋值
		for (int i = 0; i < countCenter; i++) {
			int m = centerIndexes[i];
			for (int j = 0; j < data.dim; j++) {
				centers[i][j] = data.data[m][j];
			}
		}

		// 给最初始的聚类中心赋值
		model.originalCenters = new double[countCenter][data.dim];
		for (int i = 0; i < countCenter; i++) {
			for (int j = 0; j < data.dim; j++) {
				model.originalCenters[i][j] = centers[i][j];
			}
		}

		//初始聚类
		for (int i = 0; i < data.length; i++) {
			double minDist = dist(data.data[i], centers[0], data.dim);
			int label = 0;
			for (int j = 1; j < countCenter; j++) {
				double tempDist = dist(data.data[i], centers[j], data.dim);
				if (tempDist < minDist) {
					minDist = tempDist;
					label = j;
				}
			}
			labels[i] = label;
			centerCounts[label]++;
		}
		updateCenters(countCenter, data);
		copyCenters(oldCenters, centers, countCenter, data.dim);

		// 迭代预处理
		int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS;
		int attempts = 1;
		double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA;
		double criteriaBreakCondition = 0;
		boolean[] flags = new boolean[k]; // 用来表示聚类中心是否发生变化

		// 迭代
		iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值
			for (int i = 0; i < countCenter; i++) { //  初始化中心点"是否被修改过"标记
				flags[i] = false;
			}
			for (int i = 0; i < data.length; i++) { 
				double minDist = dist(data.data[i], centers[0], data.dim);
				int label = 0;
				for (int j = 1; j < countCenter; j++) {
					double tempDist = dist(data.data[i], centers[j], data.dim);
					if (tempDist < minDist) {
						minDist = tempDist;
						label = j;
					}
				}
				if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新
					int oldLabel = labels[i];
					labels[i] = label;
					centerCounts[oldLabel]--;
					centerCounts[label]++;
					flags[oldLabel] = true;
					flags[label] = true;
				}
			}
			updateCenters(countCenter, data);
			attempts++;

			// 计算被修改过的中心点最大修改量是否超过阈值
			double maxDist = 0;
			for (int i = 0; i < countCenter; i++) {
				if (flags[i]) {
					double tempDist = dist(centers[i], oldCenters[i], data.dim);
					if (maxDist < tempDist) {
						maxDist = tempDist;
					}
					for (int j = 0; j < data.dim; j++) { // 更新oldCenter
						oldCenters[i][j] = centers[i][j];
						oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j]));
					}
				}
			}
			if (maxDist < criteria) {
				criteriaBreakCondition = maxDist;
				break iterate;
			}
		}
		// 把结果存储到ClusterModel中
		ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, countCenter, attempts, param, centerCounts);
		return rvInfo;
	}

	private ClusterModel train(int k, Kmeans_param param) {
		// 首先进行数据归一化处理
		normalize(data);
		
		List<Double> dists = new ArrayList<Double>();
		for (int i = 1; i < data.length; i++) {
			dists.add(dist(data.data[0], data.data[i], data.dim));
		}

		param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists) + Collections.min(dists)) / 2));
		double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance
				: Kmeans_param.MIN_EuclideanDistance;

		
		double[][] centers = new double[k][data.dim]; 
		data.centers = centers;
		int[] centerCounts = new int[k]; 
		model.centerCounts = centerCounts;
		Arrays.fill(centerCounts, 0);
		int[] labels = new int[data.length]; 
		model.labels = labels;
		double[][] oldCenters = new double[k][data.dim]; 

		
		int centerIndexes[] = new int[k];
		int countCenter = 0;
		int count = 0;
		centerIndexes[0] = 0;
		countCenter++;
		for (int i = 1; i < data.length; i++) {
			for (int j = 0; j < countCenter; j++) {
				if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) {
					count++;
				}
			}
			if (count == countCenter) {
				centerIndexes[countCenter++] = i;
			}
			count = 0;
			
			if (countCenter == k) {
				break;
			}
			
			if (countCenter < k && i == data.length - 1) {
				k = countCenter;
				break;
			}
		}
		
		for (int i = 0; i < k; i++) {
			int m = centerIndexes[i];
			for (int j = 0; j < data.dim; j++) {
				centers[i][j] = data.data[m][j];
			}
		}

		
		model.originalCenters = new double[k][data.dim];
		for (int i = 0; i < k; i++) {
			for (int j = 0; j < data.dim; j++) {
				model.originalCenters[i][j] = centers[i][j];
			}
		}

		
		for (int i = 0; i < data.length; i++) {
			double minDist = dist(data.data[i], centers[0], data.dim);
			int label = 0;
			for (int j = 1; j < k; j++) {
				double tempDist = dist(data.data[i], centers[j], data.dim);
				if (tempDist < minDist) {
					minDist = tempDist;
					label = j;
				}
			}
			labels[i] = label;
			centerCounts[label]++;
		}
		updateCenters(k, data);
		copyCenters(oldCenters, centers, k, data.dim);
		
		int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS;
		int attempts = 1;
		double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA;
		double criteriaBreakCondition = 0;
		boolean[] flags = new boolean[k];
		
		iterate: while (attempts < maxAttempts) { 
			for (int i = 0; i < k; i++) {
				flags[i] = false;
			}
			for (int i = 0; i < data.length; i++) {
				double minDist = dist(data.data[i], centers[0], data.dim);
				int label = 0;
				for (int j = 1; j < k; j++) {
					double tempDist = dist(data.data[i], centers[j], data.dim);
					if (tempDist < minDist) {
						minDist = tempDist;
						label = j;
					}
				}
				if (label != labels[i]) {
					int oldLabel = labels[i];
					labels[i] = label;
					centerCounts[oldLabel]--;
					centerCounts[label]++;
					flags[oldLabel] = true;
					flags[label] = true;
				}
			}
			updateCenters(k, data);
			attempts++;
			
			double maxDist = 0;
			for (int i = 0; i < k; i++) {
				if (flags[i]) {
					double tempDist = dist(centers[i], oldCenters[i], data.dim);
					if (maxDist < tempDist) {
						maxDist = tempDist;
					}
					for (int j = 0; j < data.dim; j++) { // 锟斤拷锟斤拷oldCenter
						oldCenters[i][j] = centers[i][j];
						oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j]));
					}
				}
			}
			if (maxDist < criteria) {
				criteriaBreakCondition = maxDist;
				break iterate;
			}
		}
	
		ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, k, attempts, param, centerCounts);
		return rvInfo;
	}

	/**
	 * 把聚类结果存储到Model中 
	 * @param criteriaBreakCondition
	 * @param k
	 * @param attempts
	 * @param param
	 * @param centerCounts
	 * @return
	 */
	private ClusterModel outputClusterInfo(double criteriaBreakCondition, int k, int attempts, Kmeans_param param,
			int[] centerCounts) {		
		model.data = data;
		model.k = k;
		int perm[] = new int[data.length];
		model.perm = perm;
		int start[] = new int[k];
		model.start = start;
		group_class(perm, start, k, data);
		return model;
	}

	/**
	 * 把聚类样本按所属类别连续存储
	 * @param perm
	 * @param start
	 * @param k
	 * @param data
	 */
	private void group_class(int perm[], int start[], int k, Kmeans_data data) {
		
		start[0] = 0;
		for (int i = 1; i < k; i++) {
			start[i] = start[i - 1] + model.centerCounts[i - 1];
		}		
		
		for (int i = 0; i < data.length; i++) {			
			perm[start[model.labels[i]]++] = i;
		}

		start[0] = 0;
		for (int i = 1; i < k; i++) {
			start[i] = start[i - 1] + model.centerCounts[i - 1];
		}
	}

	/**
	 * 数据归一化处理 
	 * @param data
	 * @author TongXueQiang
	 */
	private void normalize(Kmeans_data data) {		
		Map<Integer, Double[]> minAndMax = new HashMap<Integer, Double[]>();
		for (int i = 0; i < data.dim; i++) {
			Double[] nums = new Double[2];
			double max = data.data[0][i];
			double min = data.data[data.length - 1][i];
			for (int j = 0; j < data.length; j++) {
				if (data.data[j][i] > max) {
					max = data.data[j][i];
				}
				if (data.data[j][i] < min) {
					min = data.data[j][i];
				}
			}
			nums[0] = min;
			nums[1] = max;
			minAndMax.put(i, nums);
		}		
		for (int i = 0; i < data.length; i++) {
			for (int j = 0; j < data.dim; j++) {
				double minValue = minAndMax.get(j)[0];
				double maxValue = minAndMax.get(j)[1];
				data.data[i][j] = (data.data[i][j] - minValue) / (maxValue - minValue);
				data.data[i][j] = Double.valueOf(df.format(data.data[i][j]));
			}
		}
	}
}

 测试代码:

package com.txq.kmeans.test;

import org.junit.Test;
import com.txq.kmeans.ClusterModel;
import com.txq.kmeans.Kmeans;
/**
 * 
 * @author XueQiang Tong
 * train方法有两种,一个不需要传递K值,算法内部自动处理为最优值,此为最细粒度聚类,另一个需要传递K值,k值大小任意,当k值>算法内部最优值时,自动调整 为最优值
 * 利用model预测时,只需传递feature标识  
 */
public class KmeansTest {	
	@Test
	public void test() throws Exception {		
		Kmeans kmeans = new Kmeans();
		String path = "F:\\kmeans.txt";
		ClusterModel model = kmeans.train(path);	
		model.centers();
		System.out.println("中国属于第" + (model.predict("中国")+1)+"类");
		model.outputAllResult();
		System.out.println("-------------------------------------------------------------------------------------");
		model = kmeans.train(path,100000);
		model.centers();
		System.out.println("中国属于第" + (model.predict("中国")+1)+"类");
		model.outputAllResult();		
	}	
}

 看一下输出结果:

聚类中心:
1.0 1.0 0.5     第1类:    样本个数:10
0.33 0.0 0.19     第2类:    样本个数:2
0.24 0.76 0.25     第3类:    样本个数:2
0.7 0.56 1.0     第4类:    样本个数:1
中国属于第1类

最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
第4类成员:
0.7 0.56 1.0 :朝鲜
-------------------------------------------------------------------------------------
聚类中心:
1.0 1.0 0.5     第1类:    样本个数:10
0.33 0.0 0.19     第2类:    样本个数:2
0.24 0.76 0.25     第3类:    样本个数:2
0.7 0.56 1.0     第4类:    样本个数:1
中国属于第1类

最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
第4类成员:
0.7 0.56 1.0 :朝鲜

现在更改一下k值,设为3,看看效果:

聚类中心:
1.0 1.0 0.5     第1类:    样本个数:10
0.33 0.0 0.19     第2类:    样本个数:2
0.24 0.76 0.25     第3类:    样本个数:2
0.7 0.56 1.0     第4类:    样本个数:1
中国属于第1类

最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
第4类成员:
0.7 0.56 1.0 :朝鲜
-------------------------------------------------------------------------------------
聚类中心:
1.0 1.0 0.5     第1类:    样本个数:11
0.33 0.0 0.19     第2类:    样本个数:2
0.24 0.76 0.25     第3类:    样本个数:2
中国属于第1类

最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
0.7 0.56 1.0 :朝鲜
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特

现在评估一下算法的准确度:

数据分析:以下数据为欧氏距离对比,其中簇中心为中国,日本,朝鲜和伊朗,分别代表了4个梯队。欧氏距离的均值为0.68。

中国-朝鲜:0.730479294709987
巴林-朝鲜:0.5385164807134504
巴林-中国:0.38418745424597095

上述数据中,中国-朝鲜是准确的,属于不同类别。巴林与簇中心朝鲜的欧氏距离大于与中国的距离,所以聚类的时候,与中国是一类。由于是按顺序扫描,降低了不确定性和时间复杂度。如果训练数据顺序调整了,选取了巴林作为簇中心的话,虽然从算法上看是准确的,但是,效果并不是最好的。这个算法的缺点是,对训练数据的顺序比较敏感。但是,总体情况,此算法的准确率非常高,而且聚类结果是稳定的,并且与原来相比,降低了一个数量级的时间复杂度,可以满足实际工程需要。

更多精彩博客推荐,语义相似度经典:http://www.cnblogs.com/txq157/p/7425781.html

 1 package com.txq.kmeans;
 2 
 3 /**
 4  * 
 5  * @author TongXueQiang
 6  * @param data 原始矩阵
 7  * @param labels 所属类别
 8  * @param centers 簇中心 
 9  */
10 public class Kmeans_data {
11     public double[][] data;
12     public int length;
13     public int dim;    
14     public double[][] centers;    
15     
16     public Kmeans_data(double[][] data, int length, int dim) {
17         this.data = data;
18         this.length = length;
19         this.dim = dim;        
20     }
21 
  1 package com.txq.kmeans;
  2 
  3 import java.io.BufferedReader;
  4 import java.io.FileReader;
  5 import java.text.DecimalFormat;
  6 import java.util.ArrayList;
  7 import java.util.Arrays;
  8 import java.util.Collections;
  9 import java.util.HashMap;
 10 import java.util.HashSet;
 11 import java.util.List;
 12 import java.util.Map;
 13 import java.util.Random;
 14 import java.util.Set;
 15 
 16 /**
 17  * Kmeans聚类算法
 18  * 
 19  * @author TongXueQiang
 20  * @date 2016/11/09
 21  */
 22 public class Kmeans {
 23     private DecimalFormat df = new DecimalFormat("#####.00");
 24     public Kmeans_data data = null;
 25     //feature身份标识与索引的映射
 26     private Map<String,Integer> identifier = new HashMap<String,Integer>();    
 27     private Map<Integer,String> iden0 = new HashMap<Integer,String>();
 28     private ClusterModel model = new ClusterModel();
 29     
 30     /**
 31      * 文件到矩阵的映射
 32      * @param path
 33      * @return
 34      * @throws Exception
 35      */
 36     public double [][] fileToMatrix(String path) throws Exception{
 37         List<String> contents = new ArrayList<String>();        
 38         model.identifier = identifier;
 39         model.iden0 = iden0;
 40         
 41         BufferedReader bf = new BufferedReader(new FileReader(path));
 42         String str = null;        
 43         int rows = 0;
 44         int dim = 0;
 45         
 46         while((str = bf.readLine()) != null) {
 47             contents.add(str);
 48             ++rows;
 49         }        
 50         bf.close();        
 51         String []strs = contents.get(0).split(":");    
 52         dim = strs[0].split(" ").length;
 53         
 54         double [][]da = new double[rows][dim];
 55         
 56         for(int j = 0;j < contents.size();j++){
 57             strs = contents.get(j).split(":");
 58             identifier.put(strs[1],j);
 59             iden0.put(j,strs[1]);
 60             String []feature = strs[0].split(" ");
 61             for(int i = 0;i < dim;i++){                
 62                 da[j][i] = Double.parseDouble(feature[i]);
 63             }                        
 64         }
 65         
 66         return da;
 67     }
 68 
 69     /**
 70      * double[][] 元素全置
 71      * 
 72      * @param matrix
 73      *            double[][]
 74      * @param highDim
 75      *            int
 76      * @param lowDim
 77      *            int <br/>
 78      *            double[highDim][lowDim]
 79      */
 80     private void setDouble2Zero(double[][] matrix, int highDim, int lowDim) {
 81         for (int i = 0; i < highDim; i++) {
 82             for (int j = 0; j < lowDim; j++) {
 83                 matrix[i][j] = 0;
 84             }
 85         }
 86     }
 87 
 88     /**
 89      * 拷贝源二维矩阵元素到目标二维矩阵。 foreach (dests[highDim][lowDim] =
 90      * sources[highDim][lowDim]);
 91      * 
 92      * @param dests
 93      *            double[][]
 94      * @param sources
 95      *            double[][]
 96      * @param highDim
 97      *            int
 98      * @param lowDim
 99      *            int
100      */
101     private void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) {
102         for (int i = 0; i < highDim; i++) {
103             for (int j = 0; j < lowDim; j++) {
104                 dests[i][j] = sources[i][j];
105             }
106         }
107     }
108 
109     /**
110      * 更新聚类中心坐标
111      * 
112      * @param k
113      *            int 分类个数
114      * @param data
115      *            kmeans_data
116      */
117     private void updateCenters(int k, Kmeans_data data) {
118         double[][] centers = data.centers;
119         setDouble2Zero(centers, k, data.dim);// 归零处理
120         int[] labels = model.labels;
121         int[] centerCounts = model.centerCounts;
122         for (int i = 0; i < data.dim; i++) {
123             for (int j = 0; j < data.length; j++) {
124                 centers[labels[j]][i] += data.data[j][i];
125             }
126         }
127         for (int i = 0; i < k; i++) {
128             for (int j = 0; j < data.dim; j++) {
129                 centers[i][j] = centers[i][j] / centerCounts[i];
130                 // centers[i][j] =
131                 // Double.parseDouble(df.format(centers[i][j]).toString());
132             }
133         }
134     }
135 
136     /**
137      * 计算两点欧氏距离
138      * 
139      * @param pa
140      *            double[]
141      * @param pb
142      *            double[]
143      * @param dim
144      *            int 维数
145      * @return double 距离
146      */
147     public double dist(double[] pa, double[] pb, int dim) {
148         double rv = 0;
149         for (int i = 0; i < dim; i++) {
150             double temp = pa[i] - pb[i];
151             temp = temp * temp;
152             rv += temp;
153         }
154         return Math.sqrt(rv);
155     }
156 
157     /**
158      * 外部调用有k值的聚类方法,非最优解
159      * 
160      * @param k
161      * @param data
162      * @return
163      * @throws Exception 
164      */
165     public ClusterModel train(String path,int k) throws Exception {
166         double [][]matrix = fileToMatrix(path);
167         data = new Kmeans_data(matrix, matrix.length, matrix[0].length);
168         return train(k, new Kmeans_param());
169     }
170 
171     /**
172      * 外部调用无k值的聚类方法,最优解
173      * 
174      * @param data
175      * @return
176      * @throws Exception 
177      */
178     public ClusterModel train(String path) throws Exception {
179         double [][]matrix = fileToMatrix(path);//文件到矩阵的映射
180         data = new Kmeans_data(matrix, matrix.length, matrix[0].length);
181         return train(new Kmeans_param());
182     }
183 
184     private ClusterModel train(Kmeans_param param) {
185         int k = param.K;
186         // 对数据进行规一化处理,以消除大的数据的影响
187         normalize(data);
188         
189         // 寻找欧氏距离的均值
190         List<Double> dists = new ArrayList<Double>();
191         for (int i = 1; i < data.length; i++) {            
192             dists.add(dist(data.data[0], data.data[i], data.dim));
193             
194         }        
195         param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists)+Collections.min(dists))/2));
196         double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance
197                 : Kmeans_param.MIN_EuclideanDistance;
198         
199         // 预处理
200         double[][] centers = new double[k][data.dim]; // 聚类中心点集
201         data.centers = centers;
202         int[] centerCounts = new int[k]; // 各聚类的包含点个数
203         model.centerCounts = centerCounts;
204         Arrays.fill(centerCounts, 0);
205         int[] labels = new int[data.length]; // 各个点所属聚类标号
206         model.labels = labels;
207         double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标
208 
209         // 初始化聚类中心
210         int centerIndexes[] = new int[16];// 预初始化16个簇组中心
211         int countCenter = 0;// 动态表示簇中心个数
212         int count = 0;// 计数器
213         centerIndexes[0] = 0;
214         countCenter++;
215         for (int i = 1; i < data.length; i++) {
216             for (int j = 0; j < countCenter; j++) {
217                 if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) {
218                     count++;
219                 }
220             }
221             if (count == countCenter) {
222                 centerIndexes[countCenter++] = i;
223             }
224             count = 0;// 计数器清零
225             // 如果达到了k值,提前终止
226             if (countCenter == k) {
227                 break;
228             }
229             // 如果遍历了整个数据,仍然没有找到合适的中心点的话,把k自动降低为countCeneter,使簇中心个数更加趋于合理化
230             if (countCenter < k && i == data.length - 1) {
231                 k = countCenter;
232                 break;
233             }
234         }
235         // 给centers赋值
236         for (int i = 0; i < k; i++) {
237             int m = centerIndexes[i];
238             for (int j = 0; j < data.dim; j++) {
239                 centers[i][j] = data.data[m][j];
240             }
241         }
242 
243         // 给最初的聚类中心赋值
244         model.originalCenters = new double[k][data.dim];
245         for (int i = 0; i < k; i++) {
246             for (int j = 0; j < data.dim; j++) {
247                 model.originalCenters[i][j] = centers[i][j];
248             }
249         }
250 
251         // 第一轮迭代
252         for (int i = 0; i < data.length; i++) {
253             double minDist = dist(data.data[i], centers[0], data.dim);
254             int label = 0;
255             for (int j = 1; j < k; j++) {
256                 double tempDist = dist(data.data[i], centers[j], data.dim);
257                 if (tempDist < minDist) {
258                     minDist = tempDist;
259                     label = j;
260                 }
261             }
262             labels[i] = label;
263             centerCounts[label]++;
264         }
265         updateCenters(k, data);// 更新簇中心
266         copyCenters(oldCenters, centers, k, data.dim);
267 
268         // 迭代预处理
269         int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS;
270         int attempts = 1;
271         double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA;
272         double criteriaBreakCondition = 0;
273         boolean[] flags = new boolean[k]; // 标记哪些中心被修改过
274 
275         // 迭代
276         iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值
277             for (int i = 0; i < k; i++) { // 初始化中心点“是否被修改过”标记
278                 flags[i] = false;
279             }
280             for (int i = 0; i < data.length; i++) { // 遍历data内所有点
281                 double minDist = dist(data.data[i], centers[0], data.dim);
282                 int label = 0;
283                 for (int j = 1; j < k; j++) {
284                     double tempDist = dist(data.data[i], centers[j], data.dim);
285                     if (tempDist < minDist) {
286                         minDist = tempDist;
287                         label = j;
288                     }
289                 }
290                 if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新
291                     int oldLabel = labels[i];
292                     labels[i] = label;
293                     centerCounts[oldLabel]--;
294                     centerCounts[label]++;
295                     flags[oldLabel] = true;
296                     flags[label] = true;
297                 }
298             }
299             updateCenters(k, data);
300             attempts++;
301 
302             // 计算被修改过的中心点最大修改量是否超过阈值
303             double maxDist = 0;
304             for (int i = 0; i < k; i++) {
305                 if (flags[i]) {
306                     double tempDist = dist(centers[i], oldCenters[i], data.dim);
307                     if (maxDist < tempDist) {
308                         maxDist = tempDist;
309                     }
310                     for (int j = 0; j < data.dim; j++) { // 更新oldCenter
311                         oldCenters[i][j] = centers[i][j];
312                         oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j]));
313                     }
314                 }
315             }
316             if (maxDist < criteria) {
317                 criteriaBreakCondition = maxDist;
318                 break iterate;
319             }
320         }
321         //输出训练模型
322         ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, k, attempts, param, centerCounts);
323         return rvInfo;
324     }
325 
326     /**
327      * 做Kmeans运算,需要手动设置K值
328      * 
329      * @param k
330      *            int 聚类个数
331      * @param data
332      *            kmeans_data kmeans数据类
333      * @param param
334      *            kmeans_param kmeans参数类
335      * @return kmeans_result kmeans运行信息类
336      */
337     private ClusterModel train(int k, Kmeans_param param) {
338         // 对数据进行规一化处理,以消除大的数据的影响
339         normalize(data);
340         
341         // 寻找欧氏距离的均值
342         List<Double> dists = new ArrayList<Double>();
343         for (int i = 1; i < data.length; i++) {            
344             dists.add(dist(data.data[0], data.data[i], data.dim));            
345         }        
346         
347         param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists)+Collections.min(dists))/2));
348         double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance
349                 : Kmeans_param.MIN_EuclideanDistance;
350         
351         // 预处理
352         double[][] centers = new double[k][data.dim]; // 聚类中心点集
353         data.centers = centers;
354         int[] centerCounts = new int[k]; // 各聚类的包含点个数
355         model.centerCounts = centerCounts;
356         Arrays.fill(centerCounts, 0);
357         int[] labels = new int[data.length]; // 各个点所属聚类标号
358         model.labels = labels;
359         double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标
360 
361         // 初始化聚类中心(依序选择data内的k个不重复点)
362         int centerIndexes[] = new int[16];// 预初始化16个簇组中心
363         int countCenter = 0;// 动态表示簇中心个数
364         int count = 0;// 计数器
365         centerIndexes[0] = 0;
366         countCenter++;
367         for (int i = 1; i < data.length; i++) {
368             for (int j = 0; j < countCenter; j++) {
369                 if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) {
370                     count++;
371                 }
372             }
373             if (count == countCenter) {
374                 centerIndexes[countCenter++] = i;
375             }
376             count = 0;// 计数器清零
377             // 如果达到了k值,提前终止
378             if (countCenter == k) {
379                 break;
380             }
381             // 如果遍历了整个数据,仍然没有找到合适的中心点的话,把k自动降低为countCeneter,使簇中心个数更加趋于合理化
382             if (countCenter < k && i == data.length - 1) {
383                 k = countCenter;
384                 break;
385             }
386         }
387         // 给centers赋值
388         for (int i = 0; i < k; i++) {
389             int m = centerIndexes[i];
390             for (int j = 0; j < data.dim; j++) {
391                 centers[i][j] = data.data[m][j];
392             }
393         }
394 
395         // 给最初的聚类中心赋值
396         model.originalCenters = new double[k][data.dim];
397         for (int i = 0; i < k; i++) {
398             for (int j = 0; j < data.dim; j++) {
399                 model.originalCenters[i][j] = centers[i][j];
400             }
401         }
402 
403         // 第一轮迭代
404         for (int i = 0; i < data.length; i++) {
405             double minDist = dist(data.data[i], centers[0], data.dim);
406             int label = 0;
407             for (int j = 1; j < k; j++) {
408                 double tempDist = dist(data.data[i], centers[j], data.dim);
409                 if (tempDist < minDist) {
410                     minDist = tempDist;
411                     label = j;
412                 }
413             }
414             labels[i] = label;
415             centerCounts[label]++;
416         }
417         updateCenters(k, data);// 更新簇中心
418         copyCenters(oldCenters, centers, k, data.dim);
419 
420         // 迭代预处理
421         int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS;
422         int attempts = 1;
423         double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA;
424         double criteriaBreakCondition = 0;
425         boolean[] flags = new boolean[k]; // 标记哪些中心被修改过
426 
427         // 迭代
428         iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值
429             for (int i = 0; i < k; i++) { // 初始化中心点"是否被修改过"标记
430                 flags[i] = false;
431             }
432             for (int i = 0; i < data.length; i++) { // 遍历data内所有点
433                 double minDist = dist(data.data[i], centers[0], data.dim);
434                 int label = 0;
435                 for (int j = 1; j < k; j++) {
436                     double tempDist = dist(data.data[i], centers[j], data.dim);
437                     if (tempDist < minDist) {
438                         minDist = tempDist;
439                         label = j;
440                     }
441                 }
442                 if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新
443                     int oldLabel = labels[i];
444                     labels[i] = label;
445                     centerCounts[oldLabel]--;
446                     centerCounts[label]++;
447                     flags[oldLabel] = true;
448                     flags[label] = true;
449                 }
450             }
451             updateCenters(k, data);
452             attempts++;
453 
454             // 计算被修改过的中心点最大修改量是否超过阈值
455             double maxDist = 0;
456             for (int i = 0; i < k; i++) {
457                 if (flags[i]) {
458                     double tempDist = dist(centers[i], oldCenters[i], data.dim);
459                     if (maxDist < tempDist) {
460                         maxDist = tempDist;
461                     }
462                     for (int j = 0; j < data.dim; j++) { // 更新oldCenter
463                         oldCenters[i][j] = centers[i][j];
464                         oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j]));
465                     }
466                 }
467             }
468             if (maxDist < criteria) {
469                 criteriaBreakCondition = maxDist;
470                 break iterate;
471             }
472         }
473 
474         // 输出信息,把属于同一类的数据连续存放
475         ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, k, attempts, param, centerCounts);
476         return rvInfo;
477     }
478 
479     /**
480      * 输出聚类结果
481      * 
482      * @param criteriaBreakCondition
483      * @param k
484      * @param attempts
485      * @param param
486      * @param centerCounts
487      * @return
488      */
489     private ClusterModel outputClusterInfo(double criteriaBreakCondition, int k, int attempts, Kmeans_param param,
490             int[] centerCounts) {
491         // 输出信息,把属于同一类的数据连续存放    
492         model.data = data;
493         model.k = k;
494         int perm[] = new int[data.length];
495         model.perm = perm;
496         int start[] = new int[k];
497         model.start = start;
498         group_class(perm, start, k, data);        
499         return model;
500     }
501 
502     /**
503      * @author TongXueQiang
504      * @param perm
505      *            连续存放归类后的原始数据的索引
506      * @param start
507      *            每个类的起始索引位置
508      * @param k
509      *            聚类中心个数
510      * @param data
511      *            原始数据---二维矩阵
512      */
513     private void group_class(int perm[], int start[], int k, Kmeans_data data) {
514         start[0] = 0;
515         for (int i = 1; i < k; i++) {
516             start[i] = start[i - 1] + model.centerCounts[i - 1];
517         }
518 
519         for (int i = 0; i < data.length; i++) {
520             perm[start[model.labels[i]]++] = i;
521         }
522 
523         start[0] = 0;
524         for (int i = 1; i < k; i++) {
525             start[i] = start[i - 1] + model.centerCounts[i - 1];
526         }
527     }
528 
529     /**
530      * 规一化处理
531      * 
532      * @param data
533      * @author TongXueQiang
534      */
535     private void normalize(Kmeans_data data) {
536         // 1.首先计算各个列的最大和最小值,存入map中
537         Map<Integer, Double[]> minAndMax = new HashMap<Integer, Double[]>();
538         for (int i = 0; i < data.dim; i++) {
539             Double[] nums = new Double[2];
540             double max = data.data[0][i];
541             double min = data.data[data.length - 1][i];
542             for (int j = 0; j < data.length; j++) {
543                 if (data.data[j][i] > max) {
544                     max = data.data[j][i];
545                 }
546                 if (data.data[j][i] < min) {
547                     min = data.data[j][i];
548                 }
549             }
550             nums[0] = min;
551             nums[1] = max;
552             minAndMax.put(i, nums);
553         }
554         // 2.更新矩阵的值
555         for (int i = 0; i < data.length; i++) {
556             for (int j = 0; j < data.dim; j++) {
557                 double minValue = minAndMax.get(j)[0];
558                 double maxValue = minAndMax.get(j)[1];
559                 data.data[i][j] = (data.data[i][j] - minValue) / (maxValue - minValue);
560                 data.data[i][j] = Double.valueOf(df.format(data.data[i][j]));
561             }
562         }
563     }
564 
posted @ 2016-11-15 20:05  佟学强  阅读(10718)  评论(3编辑  收藏  举报