算法 - k-means++
Kmeans++算法
Kmeans++算法,主要可以解决初始中心的选择问题,不可解决k的个数问题。
Kmeans++主要思想是选择的初始聚类中心要尽量的远。
做法:
1. 在输入的数据点中随机选一个作为第一个聚类中心。
2. 对于所有数据点,计算它与已有的聚类中心的最小距离D(x)
3. 选择一个数据点作为新增的聚类中心,选择原则:D(x)较大的点被选为聚类中心的概率较大。
4. 重复2~3步骤直到选出k个聚类中心。
5. 运行Kmeans算法。
package com.lfy.main; import java.util.ArrayList; import java.util.List; import java.util.Random; /** * K均值聚类算法 */ public class Kmeans { private int numOfCluster;// 分成多少簇 private int timeOfIteration;// 迭代次数 private int dataSetLength;// 数据集元素个数,即数据集的长度 private ArrayList<float[]> dataSet;// 数据集 private ArrayList<float[]> center;// 质心 private ArrayList<ArrayList<float[]>> cluster; //簇 private ArrayList<Float> sumOfErrorSquare;// 误差平方和 private Random random; /** * 设置需分组的原始数据集 * * @param dataSet */ public void setDataSet(ArrayList<float[]> dataSet) { this.dataSet = dataSet; } /** * 获取结果分组 * * @return 结果集 */ public ArrayList<ArrayList<float[]>> getCluster() { return cluster; } /** * 构造函数,传入需要分成的簇数量 * * @param numOfCluster * 簇数量,若numOfCluster<=0时,设置为1,若numOfCluster大于数据源的长度时,置为数据源的长度 */ public Kmeans(int numOfCluster) { if (numOfCluster <= 0) { numOfCluster = 1; } this.numOfCluster = numOfCluster; } /** * 初始化 */ private void init() { timeOfIteration = 0; random = new Random(); //如果调用者未初始化数据集,则采用内部测试数据集 if (dataSet == null || dataSet.size() == 0) { initDataSet(); } dataSetLength = dataSet.size(); //若numOfCluster大于数据源的长度时,置为数据源的长度 if (numOfCluster > dataSetLength) { numOfCluster = dataSetLength; } center = initCenters(); cluster = initCluster(); sumOfErrorSquare = new ArrayList<Float>(); //查看init质心的选取情况 printDataArray(center,"initCenter"); } /** * 如果调用者未初始化数据集,则采用内部测试数据集 */ private void initDataSet() { dataSet = new ArrayList<float[]>(); // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0 float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 }, { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 }, { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } }; for (int i = 0; i < dataSetArray.length; i++) { dataSet.add(dataSetArray[i]); } } /** * 随机选取k个质点 * 初始化中心点,分成多少簇就有多少个中心点 * * @return 中心点集 */ private ArrayList<float[]> initCenters() { ArrayList<float[]> center = new ArrayList<float[]>(); int[] randoms = new int[numOfCluster]; int temp = random.nextInt(dataSetLength); randoms[0] = temp; //---------------------- List<Integer> list=new ArrayList<Integer>(); list.add(temp); //randoms数组中存放dataSet数据集的不同的下标 for (int i = 1; i < numOfCluster; i++) { // while (true) { // temp = random.nextInt(dataSetLength); // // int j=0; // for(; j<i; j++){ // if(randoms[j] == temp){ // break; // } // } // //没有与任何一个已经选定的质心重复 // //跳出外层循环,设定一个随机质心 // if (j == i) { // break; // } // } //---------------------- ArrayList<float[]> ltemp=new ArrayList<float[]>(); //从剩下的点中继续找质点 for (int k = 0; k < dataSetLength; k++) { //如果该点还没有被选择为质点,则计算它与已有的所有质点的最小距离 if(!list.contains(k)) { float[] distance = new float[numOfCluster]; for (int j = 0; j < list.size(); j++) { //某点k到已有中心点的距离 distance[j] = distance(dataSet.get(k), dataSet.get(list.get(j))); } int j = minDistance(distance); float[] f={0,0}; f[0]=k; f[1]=distance[j]; ltemp.add(f); } } int m=maxDistance(ltemp); temp=(int) ltemp.get(m)[0]; list.add(temp); //---------------------- randoms[i] = temp; } for (int i = 0; i < numOfCluster; i++) { center.add(dataSet.get(randoms[i]));// 生成初始化中心点集 } return center; } /** * 初始化簇集合 * * @return 一个分为k簇的空数据的簇集合 */ private ArrayList<ArrayList<float[]>> initCluster() { ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>(); for (int i = 0; i < numOfCluster; i++) { cluster.add(new ArrayList<float[]>()); } return cluster; } /** * 计算两个点之间的距离 * * @param element * 点1 * @param center * 点2 * @return 距离 */ private float distance(float[] element, float[] center) { float distance = 0.0f; float x = element[0] - center[0]; float y = element[1] - center[1]; float z = x * x + y * y; distance = (float) Math.sqrt(z); return distance; } /** * 获取距离集合中最小距离的位置 * * @param distance * 距离数组 * @return 最小距离在距离数组中的位置 */ private int minDistance(float[] distance) { float minDistance = distance[0]; int minLocation = 0; for (int i = 1; i < distance.length; i++) { if (distance[i] <= minDistance) { minDistance = distance[i]; minLocation = i; } } return minLocation; } /** * 获取距离集合中最小距离的最大的位置 * * @param distance * 各点最小距离数组 * @return 各点最小距离在距离数组中的位置 */ private int maxDistance(ArrayList<float[]> distance) { float[] maxDistance = distance.get(0); int maxLocation = 0; for (int i = 1; i < distance.size(); i++) { if (distance.get(i)[1] >= maxDistance[1]) { maxDistance = distance.get(i); maxLocation = i; } } return maxLocation; } /** * 核心,将当前元素放到最小距离的簇中 */ private void clusterSet() { float[] distance = new float[numOfCluster]; for (int i = 0; i < dataSetLength; i++) { for (int j = 0; j < numOfCluster; j++) { //计算数据集点与所有中心点的距离 distance[j] = distance(dataSet.get(i), center.get(j)); } int j = minDistance(distance); // 核心,将当前元素放到最小距离中心相关的簇中 cluster.get(j).add(dataSet.get(i)); } } /** * 求族中各点到其中心点距离的平方,即误差平方 * * @param element * 点1 * @param center * 点2 * @return 误差平方 */ private float errorSquare(float[] element, float[] center) { float x = element[0] - center[0]; float y = element[1] - center[1]; float errSquare = x * x + y * y; return errSquare; } /** * 计算一次迭代误差平方和 */ private void countRule() { float jcF = 0; for (int i = 0; i < cluster.size(); i++) { for (int j = 0; j < cluster.get(i).size(); j++) { jcF += errorSquare(cluster.get(i).get(j), center.get(i)); } } sumOfErrorSquare.add(jcF); } /** * 设置新的簇中心方法 */ private void setNewCenter() { for (int i = 0; i < numOfCluster; i++) { int n = cluster.get(i).size(); if (n != 0) { float[] newCenter = { 0, 0 }; for (int j = 0; j < n; j++) { newCenter[0] += cluster.get(i).get(j)[0]; newCenter[1] += cluster.get(i).get(j)[1]; } // 设置一个平均值 newCenter[0] = newCenter[0] / n; newCenter[1] = newCenter[1] / n; center.set(i, newCenter); } } printDataArray(center,"newCenter"); } /** * 打印数据,测试用 * * @param dataArray * 数据集 * @param dataArrayName * 数据集名称 */ public void printDataArray(ArrayList<float[]> dataArray, String dataArrayName) { for (int i = 0; i < dataArray.size(); i++) { System.out.println("print:" + dataArrayName + "[" + i + "]={" + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}"); } System.out.println("==================================="); } /** * Kmeans算法核心过程方法 */ private void kmeans() { init(); // 循环分组,直到误差不变为止 while (true) { clusterSet(); countRule(); // 误差不变了,分组完成 if (timeOfIteration != 0) { if (sumOfErrorSquare.get(timeOfIteration) - sumOfErrorSquare.get(timeOfIteration - 1) == 0) { break; } } //设置各簇新的质心,继续迭代 setNewCenter(); timeOfIteration++; cluster.clear(); cluster = initCluster(); } System.out.println("note:the times of repeat:timeOfIteration="+timeOfIteration);//输出迭代次数 } /** * 执行算法 */ public void execute() { long startTime = System.currentTimeMillis(); System.out.println("kmeans begins"); kmeans(); long endTime = System.currentTimeMillis(); System.out.println("kmeans running time=" + (endTime - startTime) + "ms"); System.out.println("kmeans ends"); System.out.println(); } }
package com.lfy.main; import java.util.ArrayList; public class KmeansTest { public static void main(String[] args) { //初始化一个Kmean对象,设置k值 Kmeans k=new Kmeans(3); ArrayList<float[]> dataSet=new ArrayList<float[]>(); dataSet.add(new float[]{3,4}); dataSet.add(new float[]{4,4}); dataSet.add(new float[]{3,3}); dataSet.add(new float[]{4,3}); // dataSet.add(new float[]{0,2}); dataSet.add(new float[]{1,2}); dataSet.add(new float[]{0,1}); dataSet.add(new float[]{1,1}); // dataSet.add(new float[]{3,1}); dataSet.add(new float[]{3,0}); dataSet.add(new float[]{5,0}); dataSet.add(new float[]{4,0}); dataSet.add(new float[]{4,1}); //设置原始数据集 k.setDataSet(dataSet); //执行算法 k.execute(); //得到聚类结果 ArrayList<ArrayList<float[]>> cluster=k.getCluster(); //查看结果 for(int i=0;i<cluster.size();i++) { k.printDataArray(cluster.get(i), "cluster["+i+"]"); } } }