原创:Kmeans算法实战+改进(java实现)
EM思想很伟大,在处理含有隐式变量的机器学习算法中很有用。聚类算法包括kmeans,高斯混合聚类,快速迭代聚类等等,都离不开EM思想。在了解kmeans算法之前,有必要详细了解一下EM思想。
Kmeans算法属于无监督学习中的一种,相比于监督学习,能节省很多成本,省去了大量的标签标注。每个数据点都有自己的隐式的分类。聚类要做的是,从中选取出数个聚类中心,对数据集进行初始聚类。此后,通过更新聚类中心(把簇中心缓存起来),重新聚类,然后再更新簇中心,如果此簇中心与旧的簇中心的差值(2范数)<阈值,说明聚类趋于稳定,迭代结束。Kmeans算法,通过计算两个数据点的欧氏距离(2范数),来对数据点进行归类。这个算法和高斯混合聚类相比,要死板很多,而且,有一个最重要的弱点,就是聚类结果对初始化的簇中心比较敏感,而且容易陷入局部最优。因为,评价kmeans的损失函数属于非凸函数,不能取得全局最优解。稍后,在代码中,会有说明。如果想改进这个算法,可以考虑半聚类算法,与其他算法结合起来,削弱其弱点 。
关于算法的研究,本人人为,应该从以下三方面着手:第一境界,明白原理,从理论上获得支撑;第二层面,深刻理解算法实现,能够根据高数进行推导,并且找出算法的优劣点;第三层面,能够证明算法的正确性,并且提出改进方案。Kmeans算法是基于EM思想,Kmeans算法的挑战在于如何提高聚类的准确性和稳定性。 在改进上主要朝着上述两个方向努力。改进的时候,首先要提出理论上的支持,在实施上,主要手段围绕着改进簇中心的选取方式以及挖掘出k值的隐式最优值。改进簇中心选取方式的目标就是提高准确率和稳定性,挖掘k值隐式最优解是为了提高聚类的颗粒度,追求最优效果。因为使用算法的人,不一定保证能真正深刻理解算法,并且对于训练数据的内部规律,也不一定清晰。而且Kmeans算法,人为地在外部设置k值,这种做法,本身就存在一定的不合理性。不像监督学习,训练数据的标签,可以按照人的想法进行划分,比如设置3类,或者4类。但是,自动聚类,机器并不能做到人这么智能化。所以,关于k值的设定,有必要改进一下,让机器在一定程度上,自动识别出最优解。这样,在外部调用算法时,当用户设置的k值<隐式最优解的时候,按照k值数目进行聚类,当用户设置的k值很大时,超出了k值的隐式最优解,算法内部应该能够自动调整k值为最优解。这就是方向,有了方向后,就可以沿着这个思路去思考,尝试,测试,直到成功。另外好的算法,从代码层面上看,大都是简单易于执行的,乍一看,就那么几个数据结构。但是能够提出想法,并且从理论上需求突破,这才是最难的。最好的事务,都是很朴实的,使用起来很简单,比如微软提出来的全排列最优算法。
下面,上传本人最近编写的Kmeans算法,这个算法中,有三个地方进行了改进:①增加了数据的归一化处理,以消除大的数据的影响;②增加了数据归类算法,使输出的数据同一类别的,连续存储,使输出结果更加人性化;③使簇中心的选取方式及个数约束更加合理化。追求的效果:一为准确,二为稳定,三消除簇中心的敏感性(实际上,关于这一点永远不能消除,只能最大限度地提升准确率)。
首先,展示一下未改进前的算法:
package com.txq.kmeans; /** * * @param <b>data</b> <i>in double[length][dim]</i><br/>length个instance的坐标,第i(0~length-1)个instance为data[i] * @param <b>length</b> <i>in</i> instance个数 * @param <b>dim</b> <i>in</i> instance维数 * @param <b>labels</b> <i>out int[length]</i><br/>聚类后,instance所属的聚类标号(0~k-1) * @param <b>centers</b> <i>in out double[k][dim]</i><br/>k个聚类中心点的坐标,第i(0~k-1)个中心点为centers[i] * @author Yuanbo She * */ public class Kmeans_data { public double[][] data;//原始矩阵 public int length;//矩阵长度 public int dim;//特征维度 public int[] labels;//数据所属类别的标签,即聚类中心的索引值 public double[][] centers;//聚类中心矩阵 public int[] centerCounts;//每个聚类中心的元素个数 public double [][]originalCenters;//最初的聚类中心坐标点集 public Kmeans_data(double[][] data, int length, int dim) { this.data = data; this.length = length; this.dim = dim; } } 然后,定义聚类所需的参数: public class Kmeans_param { public static final int CENTER_ORDER = 0; public static final int CENTER_RANDOM = 1; public static final int MAX_ATTEMPTS = 4000; public static final double MIN_CRITERIA = 1.0; public static final double MIN_EuclideanDistance = 0.8; public double criteria = MIN_CRITERIA; //阈值 public int attempts = MAX_ATTEMPTS; //尝试次数 public int initCenterMethod = CENTER_RANDOM ; //初始化聚类中心点方式 public boolean isDisplay = true; //是否直接显示结果 public double min_euclideanDistance = MIN_EuclideanDistance; } 还要定义聚类显示的结果: /** * * 聚类显示的结果 * @author TongXueQiang */ public class Kmeans_result { public int attempts; // 退出迭代时的尝试次数 public double criteriaBreakCondition; // 退出迭代时的最大距离(小于阈值) public int k; // 聚类数 public int perm[];//归类后连续存放的原始数据索引 public int start[];//每个类在原始数据中的起始位置 } 接下来,开始聚类: package com.txq.kmeans; import java.text.DecimalFormat; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Random; /** * Kmeans聚类算法 * @author TongXueQiang * @date 2016/11/09 */ public class Kmeans { private static DecimalFormat df = new DecimalFormat("#####.00");//对数据格式化处理 public Kmeans_data data = null; public Kmeans(double [][]da){ data = new Kmeans_data(da,da.length,da[0].length); } /** * double[][] 元素全置 * * @param matrix * double[][] * @param highDim * int * @param lowDim * int <br/> * double[highDim][lowDim] */ private static void setDouble2Zero(double[][] matrix, int highDim, int lowDim) { for (int i = 0; i < highDim; i++) { for (int j = 0; j < lowDim; j++) { matrix[i][j] = 0; } } } /** * 拷贝源二维矩阵元素到目标二维矩阵。 foreach (dests[highDim][lowDim] = * sources[highDim][lowDim]); * * @param dests * double[][] * @param sources * double[][] * @param highDim * int * @param lowDim * int */ private static void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) { for (int i = 0; i < highDim; i++) { for (int j = 0; j < lowDim; j++) { dests[i][j] = sources[i][j]; } } } /** * 更新聚类中心坐标,实现思路为:先求簇中心的和,然后求取均值。 * * @param k * int 分类个数 * @param data * kmeans_data */ private static void updateCenters(int k, Kmeans_data data) { double[][] centers = data.centers; setDouble2Zero(centers, k, data.dim);//归零处理 int[] labels = data.labels; int[] centerCounts = data.centerCounts; for (int i = 0; i < data.dim; i++) { for (int j = 0; j < data.length; j++) { centers[labels[j]][i] += data.data[j][i]; } } for (int i = 0; i < k; i++) { for (int j = 0; j < data.dim; j++) { centers[i][j] = centers[i][j] / centerCounts[i]; centers[i][j] = Double.valueOf(df.format(centers[i][j])); } } } /** * 计算两点欧氏距离 * * @param pa * double[] * @param pb * double[] * @param dim * int 维数 * @return double 距离 */ public static double dist(double[] pa, double[] pb, int dim) { double rv = 0; for (int i = 0; i < dim; i++) { double temp = pa[i] - pb[i]; temp = temp * temp; rv += temp; } return Math.sqrt(rv); } /** * 做Kmeans运算 * * @param k * int 聚类个数 * @param data * kmeans_data kmeans数据类 * @param param * kmeans_param kmeans参数类 * @return kmeans_result kmeans运行信息类 */ public static Kmeans_result doKmeans(int k, Kmeans_param param) { //对数据进行规一化处理,以消除大的数据的影响 normalize(data); // System.out.println("规格化处理后的数据:"); // for (int i = 0;i < data.length;i++) { // for (int j = 0;j < data.dim;j++) { // System.out.print(data.data[i][j] + " "); // } // System.out.println(); // } // 预处理 double[][] centers = new double[k][data.dim]; // 聚类中心点集 data.centers = centers; int[] centerCounts = new int[k]; // 各聚类的包含点个数 data.centerCounts = centerCounts; Arrays.fill(centerCounts, 0); int[] labels = new int[data.length]; // 各个点所属聚类标号 data.labels = labels; double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标 // 初始化聚类中心(随机或者依序选择data内的k个不重复点) if (param.initCenterMethod == Kmeans_param.CENTER_RANDOM) { // 随机选取k个初始聚类中心 Random rn = new Random(); List<Integer> seeds = new LinkedList<Integer>(); while (seeds.size() < k) { int randomInt = rn.nextInt(data.length); if (!seeds.contains(randomInt)) { seeds.add(randomInt); } } Collections.sort(seeds); for (int i = 0; i < k; i++) { int m = seeds.remove(0); for (int j = 0; j < data.dim; j++) { centers[i][j] = data.data[m][j]; } } } else { // 选取前k个点位初始聚类中心 for (int i = 0; i < k; i++) { for (int j = 0; j < data.dim; j++) { centers[i][j] = data.data[i][j]; } } } //给最初的聚类中心赋值 data.originalCenters = new double[k][data.dim]; for (int i = 0; i < k; i++) { for (int j = 0; j < data.dim; j++) { data.originalCenters[i][j] = centers[i][j]; } } // 第一轮迭代 for (int i = 0; i < data.length; i++) { double minDist = dist(data.data[i], centers[0], data.dim); int label = 0; for (int j = 1; j < k; j++) { double tempDist = dist(data.data[i], centers[j], data.dim); if (tempDist < minDist) { minDist = tempDist; label = j; } } labels[i] = label; centerCounts[label]++; } updateCenters(k, data);//更新簇中心 copyCenters(oldCenters, centers, k, data.dim); // 迭代预处理 int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS; int attempts = 1; double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA; double criteriaBreakCondition = 0; boolean[] flags = new boolean[k]; // 标记哪些中心被修改过 // 迭代 iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值 for (int i = 0; i < k; i++) { // 初始化中心点“是否被修改过”标记 flags[i] = false; } for (int i = 0; i < data.length; i++) { // 遍历data内所有点 double minDist = dist(data.data[i], centers[0], data.dim); int label = 0; for (int j = 1; j < k; j++) { double tempDist = dist(data.data[i], centers[j], data.dim); if (tempDist < minDist) { minDist = tempDist; label = j; } } if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新 int oldLabel = labels[i]; labels[i] = label; centerCounts[oldLabel]--; centerCounts[label]++; flags[oldLabel] = true; flags[label] = true; } } updateCenters(k, data); attempts++; // 计算被修改过的中心点最大修改量是否超过阈值 double maxDist = 0; for (int i = 0; i < k; i++) { if (flags[i]) { double tempDist = dist(centers[i], oldCenters[i], data.dim); if (maxDist < tempDist) { maxDist = tempDist; } for (int j = 0; j < data.dim; j++) { // 更新oldCenter oldCenters[i][j] = centers[i][j]; oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j])); } } } if (maxDist < criteria) { criteriaBreakCondition = maxDist; break iterate; } } // 输出信息,把属于同一类的数据连续存放 Kmeans_result rvInfo = new Kmeans_result(); int perm[] = new int[data.length]; rvInfo.perm = perm; int start[] = new int[k]; rvInfo.start = start; group_class(perm,start,k,data); rvInfo.attempts = attempts; rvInfo.criteriaBreakCondition = criteriaBreakCondition; if (param.isDisplay) { System.out.println("最初的聚类中心:"); for(int i = 0;i < data.originalCenters.length;i++){ for(int j = 0;j < data.dim;j++){ System.out.print(data.originalCenters[i][j]+" "); } System.out.print("\t类别:"+i+"\t"+"总数:"+centerCounts[i]); System.out.println(); } System.out.println("\n聚类结果--------------------------->"); int originalCount = 0; for (int i = 0;i < k;i++) { int index = data.labels[perm[start[i]]];//所属类别 int count = data.centerCounts[index];//类别中个体数目 originalCount += count; System.out.println("所属类别:" + index); for (int j = start[i];j < originalCount;j++) { for (double num:data.data[perm[j]]) { System.out.print(num+" "); } System.out.println(); } } } return rvInfo; } /** * @author TongXueQiang * @param perm 连续存放归类后的原始数据的索引 * @param start 每个类的起始索引位置 * @param k 聚类中心个数 * @param data 原始数据---二维矩阵 */ private static void group_class(int perm[],int start[],int k,Kmeans_data data){ start[0] = 0; for(int i = 1;i < k;i++){ start[i] = start[i-1] + data.centerCounts[i-1]; } for(int i = 0;i < data.length;i++){ perm[start[data.labels[i]]++] = i; } start[0] = 0; for(int i = 1;i < k;i++){ start[i] = start[i-1] + data.centerCounts[i-1]; } } /** * 规一化处理 * @param data * @author TongXueQiang */ private static void normalize(Kmeans_data data){ //1.首先计算各个列的最大和最小值,存入map中 Map<Integer,Double[]> minAndMax = new HashMap<Integer,Double[]>(); for(int i = 0;i < data.dim;i++){ Double []nums = new Double[2]; double max = data.data[0][i]; double min = data.data[data.length-1][i]; for(int j = 0;j < data.length;j++){ if(data.data[j][i] > max){ max = data.data[j][i]; } if(data.data[j][i] < min){ min = data.data[j][i]; } } nums[0] = min; nums[1] = max; minAndMax.put(i,nums); } //2.更新矩阵的值 for(int i = 0;i < data.length;i++){ for(int j = 0;j < data.dim;j++){ double minValue = minAndMax.get(j)[0]; double maxValue = minAndMax.get(j)[1]; data.data[i][j] = (data.data[i][j] - minValue)/(maxValue - minValue); data.data[i][j] = Double.valueOf(df.format(data.data[i][j])); } } } } 测试类: package com.txq.kmeans.test; import org.junit.Test; import com.txq.kmeans.Kmeans; import com.txq.kmeans.Kmeans_data; import com.txq.kmeans.Kmeans_param; public class KmeansTest { @Test public void test() { double [][]da = new double[6][]; da[0] = new double[]{1,5,132}; da[1] = new double[]{3,7,12}; da[2] = new double[]{67,23,45}; da[3] = new double[]{34,5,13}; da[4] = new double[]{12,7,21}; da[5] = new double[]{26,23,54}; Kmeans kmeans = new Kmeans(da); kmeans.doKmeans(3); } }
输出结果,注意观察:
最初的聚类中心:
0.0 0.0 1.0 类别:0 总数:1
0.03 0.11 0.0 类别:1 总数:3
0.5 0.0 0.01 类别:2 总数:2
聚类结果--------------------------->
所属类别:0
0.0 0.0 1.0
所属类别:1
0.03 0.11 0.0
0.5 0.0 0.01
0.17 0.11 0.07
所属类别:2
1.0 1.0 0.28
0.38 1.0 0.35
观察这个结果,发现,随机初始化的三个簇中心,其中有两个的欧氏距离非常接近,属于同一类的。这种情况,聚类结果,就会有偏差,很不合理。
最初的聚类中心:
0.03 0.11 0.0 类别:0 总数:4
1.0 1.0 0.28 类别:1 总数:1
0.38 1.0 0.35 类别:2 总数:1
聚类结果--------------------------->
所属类别:0
0.0 0.0 1.0
0.03 0.11 0.0
0.5 0.0 0.01
0.17 0.11 0.07
所属类别:1
1.0 1.0 0.28
所属类别:2
0.38 1.0 0.35
最初的聚类中心:
1.0 1.0 0.28 类别:0 总数:1
0.5 0.0 0.01 类别:1 总数:4
0.38 1.0 0.35 类别:2 总数:1
聚类结果--------------------------->
所属类别:0
1.0 1.0 0.28
所属类别:1
0.0 0.0 1.0
0.03 0.11 0.0
0.5 0.0 0.01
0.17 0.11 0.07
所属类别:2
0.38 1.0 0.35
上述算法中,对初始簇中心严重依赖。具体来说,采用随机初始化的方式,聚类结果很不稳定,而且严重影响准确率。选取簇中心的原则是,每两个中心之间的欧氏距离应该尽量大。而且,k的数目应该有隐式的约束,太少或者太大都不合理。所以,应该同时约束上述两个因素。最好的办法是,用概率密度分析,比如高斯分布。把所有的训练数据中每两个数据的欧氏距离看作是基本变量,遵循Gaussian分布。原始数据全部归一化处理后,欧氏距离取值范围应该在:(0,√n)之间,借鉴二元分类的思想,取均值,如果大于均值的话,属于同一类的概率比较大,反之较小。其中,n为dimension.按照此种方法处理的话,会隐式地约束K的个数,使之更加合理。比如,训练数据中,每两个数据的欧氏距离>mean的中心点可能只有3个,如果你在外部调用算法时,人为地设定为4个或者5个的话,应该自动把K值降低为合理值。这样聚类的结果,一定是最优的。所以,要想达到最优效果,外部传递k值的时候,可以尽量地大,或者不设置,在不断测试的过程中,发现改为顺寻扫描效果更佳。但是,会增加时间复杂度。关于算法的精确度和时间复杂度,往往不能两全。转化为工程应用时,可以在牺牲一定精度的前提下,换取时间复杂度的提升。比如,在计算训练数据的欧氏距离的均值的时候,可以只考虑矩阵中的第一个数据与其他所有数据的欧氏距离,计算最大值和最小值,然后折中处理,不能计算所有的组合情况的E(期望)。准确度很高,而且把时间复杂度降低了一个数量级,原来O(n^2)变为O(n)。代码如下:
package com.txq.kmeans; import java.util.Map; /** * 聚类模型 * @author TongXueQiang * @date 2017/09/09 */ public class ClusterModel { public double originalCenters[][]; public int centerCounts[]; public int attempts; //最大迭代次数 public double criteriaBreakCondition; // 迭代结束时的最小阈值 public int[] labels; public int k; public int perm[];//连续存放的样本 public int start[];//每个中心开始的位置 public Map<String,Integer> identifier; public Kmeans_data data; public Map<Integer, String> iden0; public void centers(){ System.out.println("聚类中心:"); for (int i = 0; i < originalCenters.length; i++) { for (int j = 0; j < originalCenters[0].length; j++) { System.out.print(originalCenters[i][j] + " "); } System.out.print("\t"+"第" + (i+1)+"类:" + "\t" + "样本个数:" + centerCounts[i]); System.out.println(); } } public int predict(String iden){ int label = labels[identifier.get(iden)]; return label; } public void outputAllResult(){ System.out.println("\n最后聚类结果--------------------------->"); int originalCount = 0; for (int i = 0; i < k; i++) { int index = labels[perm[start[i]]]; int counts = centerCounts[index]; originalCount += counts; System.out.println("第"+(index+1)+"类成员:"); for (int j = start[i]; j < originalCount; j++) { for (double num : data.data[perm[j]]) { System.out.print(num + " "); } System.out.print(":"+iden0.get(perm[j])); System.out.println(); } } } }
package com.txq.kmeans; /** * * @author TongXueQiang * @param data 原始矩阵 * @param labels 样本所属类别 * @param centers 聚类中心 * @date 2017/09/09 */ public class Kmeans_data { public double[][] data; public int length; public int dim; public double[][] centers; public Kmeans_data(double[][] data, int length, int dim) { this.data = data; this.length = length; this.dim = dim; } }
package com.txq.kmeans; /** * 控制k_means迭代的参数 * @author TongXueQiang * @date 2017/09/09 */ public class Kmeans_param { public static final int K = 240;//系统默认的最大聚类中心个数 public static final int MAX_ATTEMPTS = 4000;//最大迭代次数 public static final double MIN_CRITERIA = 0.1; public static final double MIN_EuclideanDistance = 0.8; public double criteria = MIN_CRITERIA; //最小阈值 public int attempts = MAX_ATTEMPTS; public boolean isDisplay = true; public double min_euclideanDistance = MIN_EuclideanDistance; }
package com.txq.kmeans; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; /** * kMeans聚类算法 * @author TongXueQiang * @date 2017/09/09 */ public class Kmeans { private DecimalFormat df = new DecimalFormat("#####.00"); public Kmeans_data data = null; // feature,样本名称和索引映射 private Map<String, Integer> identifier = new HashMap<String, Integer>(); private Map<Integer, String> iden0 = new HashMap<Integer, String>(); private ClusterModel model = new ClusterModel(); /** * 文件到矩阵的映射 * @param path * @return * @throws Exception */ public double[][] fileToMatrix(String path) throws Exception { List<String> contents = new ArrayList<String>(); model.identifier = identifier; model.iden0 = iden0; FileInputStream file = null; InputStreamReader inputFileReader = null; BufferedReader reader = null; String str = null; int rows = 0; int dim = 0; try { file = new FileInputStream(path); inputFileReader = new InputStreamReader(file, "utf-8"); reader = new BufferedReader(inputFileReader); // 一次读入一行,直到读入null为文件结束 while ((str = reader.readLine()) != null) { contents.add(str); ++rows; } reader.close(); } catch (IOException e) { e.printStackTrace(); return null; } finally { if (reader != null) { try { reader.close(); } catch (IOException e1) { } } } String[] strs = contents.get(0).split(":"); dim = strs[0].split(" ").length; double[][] da = new double[rows][dim]; for (int j = 0; j < contents.size(); j++) { strs = contents.get(j).split(":"); identifier.put(strs[1], j); iden0.put(j, strs[1]); String[] feature = strs[0].split(" "); for (int i = 0; i < dim; i++) { da[j][i] = Double.parseDouble(feature[i]); } } return da; } /** * 清零操作 * @param matrix * @param highDim * @param lowDim */ private void setDouble2Zero(double[][] matrix, int highDim, int lowDim) { for (int i = 0; i < highDim; i++) { for (int j = 0; j < lowDim; j++) { matrix[i][j] = 0; } } } /** * 聚类中心拷贝 * @param dests * @param sources * @param highDim * @param lowDim */ private void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) { for (int i = 0; i < highDim; i++) { for (int j = 0; j < lowDim; j++) { dests[i][j] = sources[i][j]; } } } /** * 更新聚类中心 * @param k * @param data */ private void updateCenters(int k, Kmeans_data data) { double[][] centers = data.centers; setDouble2Zero(centers, k, data.dim); int[] labels = model.labels; int[] centerCounts = model.centerCounts; for (int i = 0; i < data.dim; i++) { for (int j = 0; j < data.length; j++) { centers[labels[j]][i] += data.data[j][i]; } } for (int i = 0; i < k; i++) { for (int j = 0; j < data.dim; j++) { centers[i][j] = centers[i][j] / centerCounts[i]; } } } /** * 计算欧氏距离 * @param pa * @param pb * @param dim * @return */ public double dist(double[] pa, double[] pb, int dim) { double rv = 0; for (int i = 0; i < dim; i++) { double temp = pa[i] - pb[i]; temp = temp * temp; rv += temp; } return Math.sqrt(rv); } /** * 样本训练,需要人为设定k值(聚类中心数目) * @param k * @param data * @return * @throws Exception */ public ClusterModel train(String path, int k) throws Exception { double[][] matrix = fileToMatrix(path); data = new Kmeans_data(matrix, matrix.length, matrix[0].length); return train(k, new Kmeans_param()); } /** * 样本训练(系统默认最优聚类中心数目) * @param data * @return * @throws Exception */ public ClusterModel train(String path) throws Exception { double[][] matrix = fileToMatrix(path); data = new Kmeans_data(matrix, matrix.length, matrix[0].length); return train(new Kmeans_param()); } private ClusterModel train(Kmeans_param param) { int k = Kmeans_param.K; // 首先进行数据归一化处理 normalize(data); // 计算第一个样本和后面的所有样本的欧氏距离,存入list中然后计算均值,作为聚类中心选取的依据 List<Double> dists = new ArrayList<Double>(); for (int i = 1; i < data.length; i++) { dists.add(dist(data.data[0], data.data[i], data.dim)); } param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists) + Collections.min(dists)) / 2)); double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance : Kmeans_param.MIN_EuclideanDistance; int centerIndexes[] = new int[k];// 收集聚类中心索引的数组 int countCenter = 0;// 动态表示中心的数目 int count = 0;// 计数器 centerIndexes[0] = 0; countCenter++; for (int i = 1; i < data.length; i++) { for (int j = 0; j < countCenter; j++) { if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) { count++; } } if (count == countCenter) { centerIndexes[countCenter++] = i; } count = 0; } double[][] centers = new double[countCenter][data.dim]; // 聚类中心 data.centers = centers; int[] centerCounts = new int[countCenter]; // 聚类中心的样本个数 model.centerCounts = centerCounts; Arrays.fill(centerCounts, 0); int[] labels = new int[data.length]; // 样本的类别 model.labels = labels; double[][] oldCenters = new double[countCenter][data.dim]; // 存储旧的聚类中心 // 给聚类中心赋值 for (int i = 0; i < countCenter; i++) { int m = centerIndexes[i]; for (int j = 0; j < data.dim; j++) { centers[i][j] = data.data[m][j]; } } // 给最初始的聚类中心赋值 model.originalCenters = new double[countCenter][data.dim]; for (int i = 0; i < countCenter; i++) { for (int j = 0; j < data.dim; j++) { model.originalCenters[i][j] = centers[i][j]; } } //初始聚类 for (int i = 0; i < data.length; i++) { double minDist = dist(data.data[i], centers[0], data.dim); int label = 0; for (int j = 1; j < countCenter; j++) { double tempDist = dist(data.data[i], centers[j], data.dim); if (tempDist < minDist) { minDist = tempDist; label = j; } } labels[i] = label; centerCounts[label]++; } updateCenters(countCenter, data); copyCenters(oldCenters, centers, countCenter, data.dim); // 迭代预处理 int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS; int attempts = 1; double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA; double criteriaBreakCondition = 0; boolean[] flags = new boolean[k]; // 用来表示聚类中心是否发生变化 // 迭代 iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值 for (int i = 0; i < countCenter; i++) { // 初始化中心点"是否被修改过"标记 flags[i] = false; } for (int i = 0; i < data.length; i++) { double minDist = dist(data.data[i], centers[0], data.dim); int label = 0; for (int j = 1; j < countCenter; j++) { double tempDist = dist(data.data[i], centers[j], data.dim); if (tempDist < minDist) { minDist = tempDist; label = j; } } if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新 int oldLabel = labels[i]; labels[i] = label; centerCounts[oldLabel]--; centerCounts[label]++; flags[oldLabel] = true; flags[label] = true; } } updateCenters(countCenter, data); attempts++; // 计算被修改过的中心点最大修改量是否超过阈值 double maxDist = 0; for (int i = 0; i < countCenter; i++) { if (flags[i]) { double tempDist = dist(centers[i], oldCenters[i], data.dim); if (maxDist < tempDist) { maxDist = tempDist; } for (int j = 0; j < data.dim; j++) { // 更新oldCenter oldCenters[i][j] = centers[i][j]; oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j])); } } } if (maxDist < criteria) { criteriaBreakCondition = maxDist; break iterate; } } // 把结果存储到ClusterModel中 ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, countCenter, attempts, param, centerCounts); return rvInfo; } private ClusterModel train(int k, Kmeans_param param) { // 首先进行数据归一化处理 normalize(data); List<Double> dists = new ArrayList<Double>(); for (int i = 1; i < data.length; i++) { dists.add(dist(data.data[0], data.data[i], data.dim)); } param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists) + Collections.min(dists)) / 2)); double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance : Kmeans_param.MIN_EuclideanDistance; double[][] centers = new double[k][data.dim]; data.centers = centers; int[] centerCounts = new int[k]; model.centerCounts = centerCounts; Arrays.fill(centerCounts, 0); int[] labels = new int[data.length]; model.labels = labels; double[][] oldCenters = new double[k][data.dim]; int centerIndexes[] = new int[k]; int countCenter = 0; int count = 0; centerIndexes[0] = 0; countCenter++; for (int i = 1; i < data.length; i++) { for (int j = 0; j < countCenter; j++) { if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) { count++; } } if (count == countCenter) { centerIndexes[countCenter++] = i; } count = 0; if (countCenter == k) { break; } if (countCenter < k && i == data.length - 1) { k = countCenter; break; } } for (int i = 0; i < k; i++) { int m = centerIndexes[i]; for (int j = 0; j < data.dim; j++) { centers[i][j] = data.data[m][j]; } } model.originalCenters = new double[k][data.dim]; for (int i = 0; i < k; i++) { for (int j = 0; j < data.dim; j++) { model.originalCenters[i][j] = centers[i][j]; } } for (int i = 0; i < data.length; i++) { double minDist = dist(data.data[i], centers[0], data.dim); int label = 0; for (int j = 1; j < k; j++) { double tempDist = dist(data.data[i], centers[j], data.dim); if (tempDist < minDist) { minDist = tempDist; label = j; } } labels[i] = label; centerCounts[label]++; } updateCenters(k, data); copyCenters(oldCenters, centers, k, data.dim); int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS; int attempts = 1; double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA; double criteriaBreakCondition = 0; boolean[] flags = new boolean[k]; iterate: while (attempts < maxAttempts) { for (int i = 0; i < k; i++) { flags[i] = false; } for (int i = 0; i < data.length; i++) { double minDist = dist(data.data[i], centers[0], data.dim); int label = 0; for (int j = 1; j < k; j++) { double tempDist = dist(data.data[i], centers[j], data.dim); if (tempDist < minDist) { minDist = tempDist; label = j; } } if (label != labels[i]) { int oldLabel = labels[i]; labels[i] = label; centerCounts[oldLabel]--; centerCounts[label]++; flags[oldLabel] = true; flags[label] = true; } } updateCenters(k, data); attempts++; double maxDist = 0; for (int i = 0; i < k; i++) { if (flags[i]) { double tempDist = dist(centers[i], oldCenters[i], data.dim); if (maxDist < tempDist) { maxDist = tempDist; } for (int j = 0; j < data.dim; j++) { // 锟斤拷锟斤拷oldCenter oldCenters[i][j] = centers[i][j]; oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j])); } } } if (maxDist < criteria) { criteriaBreakCondition = maxDist; break iterate; } } ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, k, attempts, param, centerCounts); return rvInfo; } /** * 把聚类结果存储到Model中 * @param criteriaBreakCondition * @param k * @param attempts * @param param * @param centerCounts * @return */ private ClusterModel outputClusterInfo(double criteriaBreakCondition, int k, int attempts, Kmeans_param param, int[] centerCounts) { model.data = data; model.k = k; int perm[] = new int[data.length]; model.perm = perm; int start[] = new int[k]; model.start = start; group_class(perm, start, k, data); return model; } /** * 把聚类样本按所属类别连续存储 * @param perm * @param start * @param k * @param data */ private void group_class(int perm[], int start[], int k, Kmeans_data data) { start[0] = 0; for (int i = 1; i < k; i++) { start[i] = start[i - 1] + model.centerCounts[i - 1]; } for (int i = 0; i < data.length; i++) { perm[start[model.labels[i]]++] = i; } start[0] = 0; for (int i = 1; i < k; i++) { start[i] = start[i - 1] + model.centerCounts[i - 1]; } } /** * 数据归一化处理 * @param data * @author TongXueQiang */ private void normalize(Kmeans_data data) { Map<Integer, Double[]> minAndMax = new HashMap<Integer, Double[]>(); for (int i = 0; i < data.dim; i++) { Double[] nums = new Double[2]; double max = data.data[0][i]; double min = data.data[data.length - 1][i]; for (int j = 0; j < data.length; j++) { if (data.data[j][i] > max) { max = data.data[j][i]; } if (data.data[j][i] < min) { min = data.data[j][i]; } } nums[0] = min; nums[1] = max; minAndMax.put(i, nums); } for (int i = 0; i < data.length; i++) { for (int j = 0; j < data.dim; j++) { double minValue = minAndMax.get(j)[0]; double maxValue = minAndMax.get(j)[1]; data.data[i][j] = (data.data[i][j] - minValue) / (maxValue - minValue); data.data[i][j] = Double.valueOf(df.format(data.data[i][j])); } } } }
测试代码:
package com.txq.kmeans.test; import org.junit.Test; import com.txq.kmeans.ClusterModel; import com.txq.kmeans.Kmeans; /** * * @author XueQiang Tong * train方法有两种,一个不需要传递K值,算法内部自动处理为最优值,此为最细粒度聚类,另一个需要传递K值,k值大小任意,当k值>算法内部最优值时,自动调整 为最优值 * 利用model预测时,只需传递feature标识 */ public class KmeansTest { @Test public void test() throws Exception { Kmeans kmeans = new Kmeans(); String path = "F:\\kmeans.txt"; ClusterModel model = kmeans.train(path); model.centers(); System.out.println("中国属于第" + (model.predict("中国")+1)+"类"); model.outputAllResult(); System.out.println("-------------------------------------------------------------------------------------"); model = kmeans.train(path,100000); model.centers(); System.out.println("中国属于第" + (model.predict("中国")+1)+"类"); model.outputAllResult(); } }
看一下输出结果:
聚类中心:
1.0 1.0 0.5 第1类: 样本个数:10
0.33 0.0 0.19 第2类: 样本个数:2
0.24 0.76 0.25 第3类: 样本个数:2
0.7 0.56 1.0 第4类: 样本个数:1
中国属于第1类
最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
第4类成员:
0.7 0.56 1.0 :朝鲜
-------------------------------------------------------------------------------------
聚类中心:
1.0 1.0 0.5 第1类: 样本个数:10
0.33 0.0 0.19 第2类: 样本个数:2
0.24 0.76 0.25 第3类: 样本个数:2
0.7 0.56 1.0 第4类: 样本个数:1
中国属于第1类
最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
第4类成员:
0.7 0.56 1.0 :朝鲜
现在更改一下k值,设为3,看看效果:
聚类中心:
1.0 1.0 0.5 第1类: 样本个数:10
0.33 0.0 0.19 第2类: 样本个数:2
0.24 0.76 0.25 第3类: 样本个数:2
0.7 0.56 1.0 第4类: 样本个数:1
中国属于第1类
最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
第4类成员:
0.7 0.56 1.0 :朝鲜
-------------------------------------------------------------------------------------
聚类中心:
1.0 1.0 0.5 第1类: 样本个数:11
0.33 0.0 0.19 第2类: 样本个数:2
0.24 0.76 0.25 第3类: 样本个数:2
中国属于第1类
最后聚类结果--------------------------->
第1类成员:
1.0 1.0 0.5 :中国
1.0 1.0 0.0 :伊拉克
1.0 0.76 0.5 :卡塔尔
1.0 0.76 0.5 :阿联酋
0.7 0.76 0.25 :乌兹别克斯坦
1.0 1.0 0.5 :泰国
1.0 1.0 0.25 :越南
1.0 1.0 0.5 :阿曼
0.7 0.76 0.5 :巴林
0.7 0.56 1.0 :朝鲜
1.0 1.0 0.5 :印尼
第2类成员:
0.33 0.0 0.19 :日本
0.0 0.15 0.12 :韩国
第3类成员:
0.24 0.76 0.25 :伊朗
0.33 0.76 0.06 :沙特
现在评估一下算法的准确度:
数据分析:以下数据为欧氏距离对比,其中簇中心为中国,日本,朝鲜和伊朗,分别代表了4个梯队。欧氏距离的均值为0.68。
中国-朝鲜:0.730479294709987
巴林-朝鲜:0.5385164807134504
巴林-中国:0.38418745424597095
上述数据中,中国-朝鲜是准确的,属于不同类别。巴林与簇中心朝鲜的欧氏距离大于与中国的距离,所以聚类的时候,与中国是一类。由于是按顺序扫描,降低了不确定性和时间复杂度。如果训练数据顺序调整了,选取了巴林作为簇中心的话,虽然从算法上看是准确的,但是,效果并不是最好的。这个算法的缺点是,对训练数据的顺序比较敏感。但是,总体情况,此算法的准确率非常高,而且聚类结果是稳定的,并且与原来相比,降低了一个数量级的时间复杂度,可以满足实际工程需要。
更多精彩博客推荐,语义相似度经典:http://www.cnblogs.com/txq157/p/7425781.html
1 package com.txq.kmeans; 2 3 /** 4 * 5 * @author TongXueQiang 6 * @param data 原始矩阵 7 * @param labels 所属类别 8 * @param centers 簇中心 9 */ 10 public class Kmeans_data { 11 public double[][] data; 12 public int length; 13 public int dim; 14 public double[][] centers; 15 16 public Kmeans_data(double[][] data, int length, int dim) { 17 this.data = data; 18 this.length = length; 19 this.dim = dim; 20 } 21
1 package com.txq.kmeans; 2 3 import java.io.BufferedReader; 4 import java.io.FileReader; 5 import java.text.DecimalFormat; 6 import java.util.ArrayList; 7 import java.util.Arrays; 8 import java.util.Collections; 9 import java.util.HashMap; 10 import java.util.HashSet; 11 import java.util.List; 12 import java.util.Map; 13 import java.util.Random; 14 import java.util.Set; 15 16 /** 17 * Kmeans聚类算法 18 * 19 * @author TongXueQiang 20 * @date 2016/11/09 21 */ 22 public class Kmeans { 23 private DecimalFormat df = new DecimalFormat("#####.00"); 24 public Kmeans_data data = null; 25 //feature身份标识与索引的映射 26 private Map<String,Integer> identifier = new HashMap<String,Integer>(); 27 private Map<Integer,String> iden0 = new HashMap<Integer,String>(); 28 private ClusterModel model = new ClusterModel(); 29 30 /** 31 * 文件到矩阵的映射 32 * @param path 33 * @return 34 * @throws Exception 35 */ 36 public double [][] fileToMatrix(String path) throws Exception{ 37 List<String> contents = new ArrayList<String>(); 38 model.identifier = identifier; 39 model.iden0 = iden0; 40 41 BufferedReader bf = new BufferedReader(new FileReader(path)); 42 String str = null; 43 int rows = 0; 44 int dim = 0; 45 46 while((str = bf.readLine()) != null) { 47 contents.add(str); 48 ++rows; 49 } 50 bf.close(); 51 String []strs = contents.get(0).split(":"); 52 dim = strs[0].split(" ").length; 53 54 double [][]da = new double[rows][dim]; 55 56 for(int j = 0;j < contents.size();j++){ 57 strs = contents.get(j).split(":"); 58 identifier.put(strs[1],j); 59 iden0.put(j,strs[1]); 60 String []feature = strs[0].split(" "); 61 for(int i = 0;i < dim;i++){ 62 da[j][i] = Double.parseDouble(feature[i]); 63 } 64 } 65 66 return da; 67 } 68 69 /** 70 * double[][] 元素全置 71 * 72 * @param matrix 73 * double[][] 74 * @param highDim 75 * int 76 * @param lowDim 77 * int <br/> 78 * double[highDim][lowDim] 79 */ 80 private void setDouble2Zero(double[][] matrix, int highDim, int lowDim) { 81 for (int i = 0; i < highDim; i++) { 82 for (int j = 0; j < lowDim; j++) { 83 matrix[i][j] = 0; 84 } 85 } 86 } 87 88 /** 89 * 拷贝源二维矩阵元素到目标二维矩阵。 foreach (dests[highDim][lowDim] = 90 * sources[highDim][lowDim]); 91 * 92 * @param dests 93 * double[][] 94 * @param sources 95 * double[][] 96 * @param highDim 97 * int 98 * @param lowDim 99 * int 100 */ 101 private void copyCenters(double[][] dests, double[][] sources, int highDim, int lowDim) { 102 for (int i = 0; i < highDim; i++) { 103 for (int j = 0; j < lowDim; j++) { 104 dests[i][j] = sources[i][j]; 105 } 106 } 107 } 108 109 /** 110 * 更新聚类中心坐标 111 * 112 * @param k 113 * int 分类个数 114 * @param data 115 * kmeans_data 116 */ 117 private void updateCenters(int k, Kmeans_data data) { 118 double[][] centers = data.centers; 119 setDouble2Zero(centers, k, data.dim);// 归零处理 120 int[] labels = model.labels; 121 int[] centerCounts = model.centerCounts; 122 for (int i = 0; i < data.dim; i++) { 123 for (int j = 0; j < data.length; j++) { 124 centers[labels[j]][i] += data.data[j][i]; 125 } 126 } 127 for (int i = 0; i < k; i++) { 128 for (int j = 0; j < data.dim; j++) { 129 centers[i][j] = centers[i][j] / centerCounts[i]; 130 // centers[i][j] = 131 // Double.parseDouble(df.format(centers[i][j]).toString()); 132 } 133 } 134 } 135 136 /** 137 * 计算两点欧氏距离 138 * 139 * @param pa 140 * double[] 141 * @param pb 142 * double[] 143 * @param dim 144 * int 维数 145 * @return double 距离 146 */ 147 public double dist(double[] pa, double[] pb, int dim) { 148 double rv = 0; 149 for (int i = 0; i < dim; i++) { 150 double temp = pa[i] - pb[i]; 151 temp = temp * temp; 152 rv += temp; 153 } 154 return Math.sqrt(rv); 155 } 156 157 /** 158 * 外部调用有k值的聚类方法,非最优解 159 * 160 * @param k 161 * @param data 162 * @return 163 * @throws Exception 164 */ 165 public ClusterModel train(String path,int k) throws Exception { 166 double [][]matrix = fileToMatrix(path); 167 data = new Kmeans_data(matrix, matrix.length, matrix[0].length); 168 return train(k, new Kmeans_param()); 169 } 170 171 /** 172 * 外部调用无k值的聚类方法,最优解 173 * 174 * @param data 175 * @return 176 * @throws Exception 177 */ 178 public ClusterModel train(String path) throws Exception { 179 double [][]matrix = fileToMatrix(path);//文件到矩阵的映射 180 data = new Kmeans_data(matrix, matrix.length, matrix[0].length); 181 return train(new Kmeans_param()); 182 } 183 184 private ClusterModel train(Kmeans_param param) { 185 int k = param.K; 186 // 对数据进行规一化处理,以消除大的数据的影响 187 normalize(data); 188 189 // 寻找欧氏距离的均值 190 List<Double> dists = new ArrayList<Double>(); 191 for (int i = 1; i < data.length; i++) { 192 dists.add(dist(data.data[0], data.data[i], data.dim)); 193 194 } 195 param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists)+Collections.min(dists))/2)); 196 double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance 197 : Kmeans_param.MIN_EuclideanDistance; 198 199 // 预处理 200 double[][] centers = new double[k][data.dim]; // 聚类中心点集 201 data.centers = centers; 202 int[] centerCounts = new int[k]; // 各聚类的包含点个数 203 model.centerCounts = centerCounts; 204 Arrays.fill(centerCounts, 0); 205 int[] labels = new int[data.length]; // 各个点所属聚类标号 206 model.labels = labels; 207 double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标 208 209 // 初始化聚类中心 210 int centerIndexes[] = new int[16];// 预初始化16个簇组中心 211 int countCenter = 0;// 动态表示簇中心个数 212 int count = 0;// 计数器 213 centerIndexes[0] = 0; 214 countCenter++; 215 for (int i = 1; i < data.length; i++) { 216 for (int j = 0; j < countCenter; j++) { 217 if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) { 218 count++; 219 } 220 } 221 if (count == countCenter) { 222 centerIndexes[countCenter++] = i; 223 } 224 count = 0;// 计数器清零 225 // 如果达到了k值,提前终止 226 if (countCenter == k) { 227 break; 228 } 229 // 如果遍历了整个数据,仍然没有找到合适的中心点的话,把k自动降低为countCeneter,使簇中心个数更加趋于合理化 230 if (countCenter < k && i == data.length - 1) { 231 k = countCenter; 232 break; 233 } 234 } 235 // 给centers赋值 236 for (int i = 0; i < k; i++) { 237 int m = centerIndexes[i]; 238 for (int j = 0; j < data.dim; j++) { 239 centers[i][j] = data.data[m][j]; 240 } 241 } 242 243 // 给最初的聚类中心赋值 244 model.originalCenters = new double[k][data.dim]; 245 for (int i = 0; i < k; i++) { 246 for (int j = 0; j < data.dim; j++) { 247 model.originalCenters[i][j] = centers[i][j]; 248 } 249 } 250 251 // 第一轮迭代 252 for (int i = 0; i < data.length; i++) { 253 double minDist = dist(data.data[i], centers[0], data.dim); 254 int label = 0; 255 for (int j = 1; j < k; j++) { 256 double tempDist = dist(data.data[i], centers[j], data.dim); 257 if (tempDist < minDist) { 258 minDist = tempDist; 259 label = j; 260 } 261 } 262 labels[i] = label; 263 centerCounts[label]++; 264 } 265 updateCenters(k, data);// 更新簇中心 266 copyCenters(oldCenters, centers, k, data.dim); 267 268 // 迭代预处理 269 int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS; 270 int attempts = 1; 271 double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA; 272 double criteriaBreakCondition = 0; 273 boolean[] flags = new boolean[k]; // 标记哪些中心被修改过 274 275 // 迭代 276 iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值 277 for (int i = 0; i < k; i++) { // 初始化中心点“是否被修改过”标记 278 flags[i] = false; 279 } 280 for (int i = 0; i < data.length; i++) { // 遍历data内所有点 281 double minDist = dist(data.data[i], centers[0], data.dim); 282 int label = 0; 283 for (int j = 1; j < k; j++) { 284 double tempDist = dist(data.data[i], centers[j], data.dim); 285 if (tempDist < minDist) { 286 minDist = tempDist; 287 label = j; 288 } 289 } 290 if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新 291 int oldLabel = labels[i]; 292 labels[i] = label; 293 centerCounts[oldLabel]--; 294 centerCounts[label]++; 295 flags[oldLabel] = true; 296 flags[label] = true; 297 } 298 } 299 updateCenters(k, data); 300 attempts++; 301 302 // 计算被修改过的中心点最大修改量是否超过阈值 303 double maxDist = 0; 304 for (int i = 0; i < k; i++) { 305 if (flags[i]) { 306 double tempDist = dist(centers[i], oldCenters[i], data.dim); 307 if (maxDist < tempDist) { 308 maxDist = tempDist; 309 } 310 for (int j = 0; j < data.dim; j++) { // 更新oldCenter 311 oldCenters[i][j] = centers[i][j]; 312 oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j])); 313 } 314 } 315 } 316 if (maxDist < criteria) { 317 criteriaBreakCondition = maxDist; 318 break iterate; 319 } 320 } 321 //输出训练模型 322 ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, k, attempts, param, centerCounts); 323 return rvInfo; 324 } 325 326 /** 327 * 做Kmeans运算,需要手动设置K值 328 * 329 * @param k 330 * int 聚类个数 331 * @param data 332 * kmeans_data kmeans数据类 333 * @param param 334 * kmeans_param kmeans参数类 335 * @return kmeans_result kmeans运行信息类 336 */ 337 private ClusterModel train(int k, Kmeans_param param) { 338 // 对数据进行规一化处理,以消除大的数据的影响 339 normalize(data); 340 341 // 寻找欧氏距离的均值 342 List<Double> dists = new ArrayList<Double>(); 343 for (int i = 1; i < data.length; i++) { 344 dists.add(dist(data.data[0], data.data[i], data.dim)); 345 } 346 347 param.min_euclideanDistance = Double.valueOf(df.format((Collections.max(dists)+Collections.min(dists))/2)); 348 double euclideanDistance = param.min_euclideanDistance > 0 ? param.min_euclideanDistance 349 : Kmeans_param.MIN_EuclideanDistance; 350 351 // 预处理 352 double[][] centers = new double[k][data.dim]; // 聚类中心点集 353 data.centers = centers; 354 int[] centerCounts = new int[k]; // 各聚类的包含点个数 355 model.centerCounts = centerCounts; 356 Arrays.fill(centerCounts, 0); 357 int[] labels = new int[data.length]; // 各个点所属聚类标号 358 model.labels = labels; 359 double[][] oldCenters = new double[k][data.dim]; // 临时缓存旧的聚类中心坐标 360 361 // 初始化聚类中心(依序选择data内的k个不重复点) 362 int centerIndexes[] = new int[16];// 预初始化16个簇组中心 363 int countCenter = 0;// 动态表示簇中心个数 364 int count = 0;// 计数器 365 centerIndexes[0] = 0; 366 countCenter++; 367 for (int i = 1; i < data.length; i++) { 368 for (int j = 0; j < countCenter; j++) { 369 if (dist(data.data[i], data.data[centerIndexes[j]], data.dim) > euclideanDistance) { 370 count++; 371 } 372 } 373 if (count == countCenter) { 374 centerIndexes[countCenter++] = i; 375 } 376 count = 0;// 计数器清零 377 // 如果达到了k值,提前终止 378 if (countCenter == k) { 379 break; 380 } 381 // 如果遍历了整个数据,仍然没有找到合适的中心点的话,把k自动降低为countCeneter,使簇中心个数更加趋于合理化 382 if (countCenter < k && i == data.length - 1) { 383 k = countCenter; 384 break; 385 } 386 } 387 // 给centers赋值 388 for (int i = 0; i < k; i++) { 389 int m = centerIndexes[i]; 390 for (int j = 0; j < data.dim; j++) { 391 centers[i][j] = data.data[m][j]; 392 } 393 } 394 395 // 给最初的聚类中心赋值 396 model.originalCenters = new double[k][data.dim]; 397 for (int i = 0; i < k; i++) { 398 for (int j = 0; j < data.dim; j++) { 399 model.originalCenters[i][j] = centers[i][j]; 400 } 401 } 402 403 // 第一轮迭代 404 for (int i = 0; i < data.length; i++) { 405 double minDist = dist(data.data[i], centers[0], data.dim); 406 int label = 0; 407 for (int j = 1; j < k; j++) { 408 double tempDist = dist(data.data[i], centers[j], data.dim); 409 if (tempDist < minDist) { 410 minDist = tempDist; 411 label = j; 412 } 413 } 414 labels[i] = label; 415 centerCounts[label]++; 416 } 417 updateCenters(k, data);// 更新簇中心 418 copyCenters(oldCenters, centers, k, data.dim); 419 420 // 迭代预处理 421 int maxAttempts = param.attempts > 0 ? param.attempts : Kmeans_param.MAX_ATTEMPTS; 422 int attempts = 1; 423 double criteria = param.criteria > 0 ? param.criteria : Kmeans_param.MIN_CRITERIA; 424 double criteriaBreakCondition = 0; 425 boolean[] flags = new boolean[k]; // 标记哪些中心被修改过 426 427 // 迭代 428 iterate: while (attempts < maxAttempts) { // 迭代次数不超过最大值,最大中心改变量不超过阈值 429 for (int i = 0; i < k; i++) { // 初始化中心点"是否被修改过"标记 430 flags[i] = false; 431 } 432 for (int i = 0; i < data.length; i++) { // 遍历data内所有点 433 double minDist = dist(data.data[i], centers[0], data.dim); 434 int label = 0; 435 for (int j = 1; j < k; j++) { 436 double tempDist = dist(data.data[i], centers[j], data.dim); 437 if (tempDist < minDist) { 438 minDist = tempDist; 439 label = j; 440 } 441 } 442 if (label != labels[i]) { // 如果当前点被聚类到新的类别则做更新 443 int oldLabel = labels[i]; 444 labels[i] = label; 445 centerCounts[oldLabel]--; 446 centerCounts[label]++; 447 flags[oldLabel] = true; 448 flags[label] = true; 449 } 450 } 451 updateCenters(k, data); 452 attempts++; 453 454 // 计算被修改过的中心点最大修改量是否超过阈值 455 double maxDist = 0; 456 for (int i = 0; i < k; i++) { 457 if (flags[i]) { 458 double tempDist = dist(centers[i], oldCenters[i], data.dim); 459 if (maxDist < tempDist) { 460 maxDist = tempDist; 461 } 462 for (int j = 0; j < data.dim; j++) { // 更新oldCenter 463 oldCenters[i][j] = centers[i][j]; 464 oldCenters[i][j] = Double.valueOf(df.format(oldCenters[i][j])); 465 } 466 } 467 } 468 if (maxDist < criteria) { 469 criteriaBreakCondition = maxDist; 470 break iterate; 471 } 472 } 473 474 // 输出信息,把属于同一类的数据连续存放 475 ClusterModel rvInfo = outputClusterInfo(criteriaBreakCondition, k, attempts, param, centerCounts); 476 return rvInfo; 477 } 478 479 /** 480 * 输出聚类结果 481 * 482 * @param criteriaBreakCondition 483 * @param k 484 * @param attempts 485 * @param param 486 * @param centerCounts 487 * @return 488 */ 489 private ClusterModel outputClusterInfo(double criteriaBreakCondition, int k, int attempts, Kmeans_param param, 490 int[] centerCounts) { 491 // 输出信息,把属于同一类的数据连续存放 492 model.data = data; 493 model.k = k; 494 int perm[] = new int[data.length]; 495 model.perm = perm; 496 int start[] = new int[k]; 497 model.start = start; 498 group_class(perm, start, k, data); 499 return model; 500 } 501 502 /** 503 * @author TongXueQiang 504 * @param perm 505 * 连续存放归类后的原始数据的索引 506 * @param start 507 * 每个类的起始索引位置 508 * @param k 509 * 聚类中心个数 510 * @param data 511 * 原始数据---二维矩阵 512 */ 513 private void group_class(int perm[], int start[], int k, Kmeans_data data) { 514 start[0] = 0; 515 for (int i = 1; i < k; i++) { 516 start[i] = start[i - 1] + model.centerCounts[i - 1]; 517 } 518 519 for (int i = 0; i < data.length; i++) { 520 perm[start[model.labels[i]]++] = i; 521 } 522 523 start[0] = 0; 524 for (int i = 1; i < k; i++) { 525 start[i] = start[i - 1] + model.centerCounts[i - 1]; 526 } 527 } 528 529 /** 530 * 规一化处理 531 * 532 * @param data 533 * @author TongXueQiang 534 */ 535 private void normalize(Kmeans_data data) { 536 // 1.首先计算各个列的最大和最小值,存入map中 537 Map<Integer, Double[]> minAndMax = new HashMap<Integer, Double[]>(); 538 for (int i = 0; i < data.dim; i++) { 539 Double[] nums = new Double[2]; 540 double max = data.data[0][i]; 541 double min = data.data[data.length - 1][i]; 542 for (int j = 0; j < data.length; j++) { 543 if (data.data[j][i] > max) { 544 max = data.data[j][i]; 545 } 546 if (data.data[j][i] < min) { 547 min = data.data[j][i]; 548 } 549 } 550 nums[0] = min; 551 nums[1] = max; 552 minAndMax.put(i, nums); 553 } 554 // 2.更新矩阵的值 555 for (int i = 0; i < data.length; i++) { 556 for (int j = 0; j < data.dim; j++) { 557 double minValue = minAndMax.get(j)[0]; 558 double maxValue = minAndMax.get(j)[1]; 559 data.data[i][j] = (data.data[i][j] - minValue) / (maxValue - minValue); 560 data.data[i][j] = Double.valueOf(df.format(data.data[i][j])); 561 } 562 } 563 } 564