k均值算法
2013-04-16 19:55 youxin 阅读(942) 评论(0) 编辑 收藏 举报

K均值算法是聚类分析中较常用的一种算法,基本思想如下:
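首先,随机地选择k个对象,每个对象代表一个簇的初始值或中心;对剩余的每个对象,根据其与各个簇均值的距离,将它指派到最相近的簇,然后计算每个簇的新均值。这个过程一直重复,直到准则函数收敛。准则函数通常取平方误差和 $E=\sum_{i=1}^{k}\sum_{p\in C_i}\lVert p-m_i\rVert^2$,其中 $C_i$ 是第 $i$ 个簇,$m_i$ 是该簇的均值(中心)。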
关于距离,有几种不同的距离公式:
求点到点群中心的距离
一般来说,求点群的中心点时,可以很简单地使用各个点 X/Y 坐标的平均值;而要判断一个点离哪个中心最近,还需要一个距离公式。下面介绍三个常用的距离公式:
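1)Minkowski Distance 公式:$d(x,y)=\left(\sum_{i=1}^{n}\lvert x_i-y_i\rvert^{\lambda}\right)^{1/\lambda}$。λ 在形式上可以取任意非零值(正数、负数,或取极限 λ→∞),但只有 λ ≥ 1 时它才满足三角不等式、构成真正的距离度量;λ→∞ 时即为 Chebyshev 距离 $d(x,y)=\max_{i}\lvert x_i-y_i\rvert$。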
2)Euclidean Distance(欧氏距离)公式 —— 即第一个公式中 λ=2 的情况:$d(x,y)=\sqrt{\sum_{i=1}^{n}(x_i-y_i)^2}$。下面代码中的 dist() 使用的就是这个公式。
3)CityBlock Distance(城市街区距离,又称曼哈顿距离)公式 —— 即第一个公式中 λ=1 的情况:$d(x,y)=\sum_{i=1}^{n}\lvert x_i-y_i\rvert$。
算法实现:
View Code
package kmeans; public class Data { private double mX=0; private double mY=0; private int mCluster=0; public Data() { return; } public Data(double x,double y) { this.X(x); this.Y(y); return; } public void X(double x) { this.mX=x; } public void Y(double y) { this.mY=y; } public double X() { return this.mX; } public double Y() { return this.mY; } public void cluster(int clusterNumber) { this.mCluster=clusterNumber; } public int cluster() { return this.mCluster; } }
View Code
package kmeans; public class Centroid { private double mX = 0.0; private double mY = 0.0; public Centroid() { return; } public Centroid(double newX, double newY) { this.mX = newX; this.mY = newY; return; } public void X(double newX) { this.mX = newX; return; } public double X() { return this.mX; } public void Y(double newY) { this.mY = newY; return; } public double Y() { return this.mY; } }
package kmeans; import java.util.ArrayList; public class KMeans { public static final int NUM_CLUSTERS=2;//TOTAL CLUSTERS public static final int TOTAL_DATA=7;//total data points public static final double SAMPLES[][]=new double[][]{ {1.0, 1.0}, {1.5, 2.0}, {3.0, 4.0}, {5.0, 7.0}, {3.5, 5.0}, {4.5, 5.0}, {3.5, 4.5} }; public ArrayList<Data> dataSet=new ArrayList<Data>(); public ArrayList<Centroid> centroids=new ArrayList<Centroid>(); public void init() { System.out.println("centroids initialized at:"); centroids.add(new Centroid(1.0,1.0));//lowest set centroids.add(new Centroid(5.0, 7.0)); // highest set. System.out.println(" ("+centroids.get(0).X()+", " + centroids.get(0).Y() + ")"); System.out.println(" (" + centroids.get(1).X() + ", " + centroids.get(1).Y() + ")"); System.out.print("\n"); } public void kMeanCluster() { final double bigNumber=Math.pow(10,10); //// some big number that's sure to be larger than our data range. double minimum=bigNumber;// // The minimum value to beat. double distance=0.0;// // The current minimum value. int sampleNumber=0; int cluster=0; boolean isStillMoving=true; Data newData=null; // Add in new data, one at a time, recalculating centroids with each new one. 
while(dataSet.size()<TOTAL_DATA) { newData=new Data(SAMPLES[sampleNumber][0], SAMPLES[sampleNumber][1]); dataSet.add(newData); minimum=bigNumber; for(int i=0;i<NUM_CLUSTERS;i++) { distance=dist(newData, centroids.get(i)); if(distance<minimum) { minimum=distance; cluster=i; } } newData.cluster(cluster); //calculate new centroids for(int i=0;i<NUM_CLUSTERS;i++) { int totalX=0; int totalY=0; int totalInCluster=0; for(int j=0;j<dataSet.size();j++) { if(dataSet.get(j).cluster()==i) { totalX+=dataSet.get(j).X(); totalY+=dataSet.get(j).Y(); totalInCluster++; } } if(totalInCluster > 0)//有可能为0吗 有 { centroids.get(i).X(totalX / totalInCluster); centroids.get(i).Y(totalY / totalInCluster); } }//end for(int i=0;i<NUM_CLUSTERS;i++) sampleNumber++; }//end while while(isStillMoving) { //calculate new centroids for(int i = 0; i < NUM_CLUSTERS; i++) { int totalX = 0; int totalY = 0; int totalInCluster = 0; for(int j = 0; j < dataSet.size(); j++) { if(dataSet.get(j).cluster() == i){ totalX += dataSet.get(j).X(); totalY += dataSet.get(j).Y(); totalInCluster++; } } if(totalInCluster > 0){ centroids.get(i).X(totalX / totalInCluster); centroids.get(i).Y(totalY / totalInCluster); } } // Assign all data to the new centroids isStillMoving = false; for(int i=0;i<dataSet.size();i++) { Data tempData=dataSet.get(i); minimum=bigNumber; for(int j=0;j<NUM_CLUSTERS;j++) { distance=dist(tempData,centroids.get(j)); if(distance<minimum) { minimum=distance; cluster=j; } } tempData.cluster(cluster); if(tempData.cluster()!=cluster) { tempData.cluster(cluster); isStillMoving=true; } } } }//end function /** * // Calculate Euclidean distance. * @param d - Data object. * @param c - Centroid object. * @return - double value. */ private static double dist(Data d, Centroid c) { return Math.sqrt(Math.pow((c.Y() - d.Y()), 2) + Math.pow((c.X() - d.X()), 2)); } }
View Code
package kmeans; public class test { /** * @param args */ public static void main(String[] args) { // TODO Auto-generated method stub KMeans k=new KMeans(); k.init(); k.kMeanCluster(); //print out clustering results for(int i=0;i<KMeans.NUM_CLUSTERS;i++) { System.out.println("Cluster " + i + " includes:"); for(int j=0;j<KMeans.TOTAL_DATA;j++) { if(k.dataSet.get(j).cluster()==i) { System.out.println(k.dataSet.get(j).X() + ", " +k.dataSet.get(j).Y() ); } } System.out.println(); } System.out.println("Centroids finalized at:"); for(int i = 0; i < KMeans.NUM_CLUSTERS; i++) { System.out.println(" (" + k.centroids.get(i).X() + ", " + k.centroids.get(i).Y()); } System.out.print("\n"); return; } }
ref:http://mnemstudio.org/clustering-k-means-example-1.htm (neat)
http://blog.jobbole.com/23157/
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步