代码改变世界

k均值算法

  youxin  阅读(942)  评论(0编辑  收藏  举报

K均值算法是聚类分析中较常用的一种算法,基本思想如下:

首先,随机地选择k个对象,每个对象代表一个簇的初始值或中心,对剩余的每个对象,根据其与各个簇均值的距离,将它指派到最相近的簇,然后计算每个簇的新均值。这个过程一直重复,直到准则函数收敛。

关于距离,有几种不同的距离公式:

求点群中心的算法

一般来说,求点群中心点的算法你可以很简单地使用各个点的X/Y坐标的平均值。不过,我这里想告诉大家另外三个求中心点的公式:

1)Minkowski Distance 公式 —— λ 可以随意取值,可以是负数,也可以是正数,或是无穷大。

$$d_{ij} = \left( \sum_{k=1}^{n} \left| x_{ik} - x_{jk} \right|^{\lambda} \right)^{1/\lambda}$$

2)Euclidean Distance 公式 —— 也就是第一个公式 λ=2 的情况

$$d_{ij} = \sqrt{ \sum_{k=1}^{n} \left( x_{ik} - x_{jk} \right)^{2} }$$

3)CityBlock Distance 公式 —— 也就是第一个公式 λ=1 的情况

$$d_{ij} = \sum_{k=1}^{n} \left| x_{ik} - x_{jk} \right|$$

算法实现:

复制代码
View Code
package kmeans;

/**
 * A two-dimensional sample point together with the index of the cluster it
 * is currently assigned to.
 *
 * <p>All three values are mutable and default to zero.
 */
public class Data {
    private double xCoord = 0;
    private double yCoord = 0;
    private int clusterIndex = 0;

    /** Creates a point at (0, 0) assigned to cluster 0. */
    public Data() {
    }

    /**
     * Creates a point at the given coordinates, assigned to cluster 0.
     *
     * @param x the x coordinate
     * @param y the y coordinate
     */
    public Data(double x, double y) {
        // Routed through the setters, exactly as the setters would be called by a client.
        this.X(x);
        this.Y(y);
    }

    /** @return the x coordinate */
    public double X() {
        return xCoord;
    }

    /** @param x the new x coordinate */
    public void X(double x) {
        xCoord = x;
    }

    /** @return the y coordinate */
    public double Y() {
        return yCoord;
    }

    /** @param y the new y coordinate */
    public void Y(double y) {
        yCoord = y;
    }

    /** @return the index of the cluster this point is assigned to */
    public int cluster() {
        return clusterIndex;
    }

    /** @param clusterNumber the index of the cluster to assign this point to */
    public void cluster(int clusterNumber) {
        clusterIndex = clusterNumber;
    }
}
复制代码
复制代码
View Code
package kmeans;

/**
 * The centre of a cluster, stored as a mutable pair of x/y coordinates.
 *
 * <p>Both coordinates default to 0.0.
 */
public class Centroid {
    private double centerX = 0.0;
    private double centerY = 0.0;

    /** Creates a centroid at (0.0, 0.0). */
    public Centroid() {
    }

    /**
     * Creates a centroid at the given position.
     *
     * @param newX the x coordinate
     * @param newY the y coordinate
     */
    public Centroid(double newX, double newY) {
        centerX = newX;
        centerY = newY;
    }

    /** @return the x coordinate */
    public double X() {
        return centerX;
    }

    /** @return the y coordinate */
    public double Y() {
        return centerY;
    }

    /** @param newX the new x coordinate */
    public void X(double newX) {
        centerX = newX;
    }

    /** @param newY the new y coordinate */
    public void Y(double newY) {
        centerY = newY;
    }
}
复制代码
复制代码
package kmeans;

import java.util.ArrayList;

/**
 * Minimal two-dimensional k-means demo over a fixed set of sample points.
 *
 * <p>Usage: call {@link #init()} to seed the centroids, then
 * {@link #kMeanCluster()} to load the samples and cluster them. The results
 * are left in {@link #dataSet} (each point carries its cluster index) and
 * {@link #centroids}.
 */
public class KMeans {
    /** Number of clusters (k). */
    public static final int NUM_CLUSTERS = 2;
    /** Number of sample points loaded into {@link #dataSet}. */
    public static final int TOTAL_DATA = 7;
    /** The sample points, one {x, y} pair per row. */
    public static final double[][] SAMPLES = new double[][]{
        {1.0, 1.0},
        {1.5, 2.0},
        {3.0, 4.0},
        {5.0, 7.0},
        {3.5, 5.0},
        {4.5, 5.0},
        {3.5, 4.5}
    };

    /** Points loaded so far, each tagged with its current cluster index. */
    public ArrayList<Data> dataSet = new ArrayList<Data>();
    /** Current cluster centres; index i is the centre of cluster i. */
    public ArrayList<Centroid> centroids = new ArrayList<Centroid>();

    /**
     * Seeds the two centroids at the lowest and highest sample points and
     * prints their starting positions to standard output.
     */
    public void init() {
        System.out.println("centroids initialized at:");
        centroids.add(new Centroid(1.0, 1.0)); // lowest sample
        centroids.add(new Centroid(5.0, 7.0)); // highest sample
        System.out.println("  (" + centroids.get(0).X() + ", " + centroids.get(0).Y() + ")");
        System.out.println("     (" + centroids.get(1).X() + ", " + centroids.get(1).Y() + ")");
        System.out.print("\n");
    }

    /**
     * Runs k-means on {@link #SAMPLES}.
     *
     * <p>Phase 1 adds the samples one at a time, assigning each to its nearest
     * centroid and recomputing the centroids after every insertion. Phase 2
     * then alternates between recomputing centroids and reassigning every
     * point until a full pass changes no assignment.
     *
     * @throws IllegalStateException if fewer than {@link #NUM_CLUSTERS}
     *         centroids exist, i.e. {@link #init()} has not been called
     */
    public void kMeanCluster() {
        if (centroids.size() < NUM_CLUSTERS) {
            throw new IllegalStateException(
                "init() must be called before kMeanCluster(): expected " + NUM_CLUSTERS
                    + " centroids but found " + centroids.size());
        }

        // Phase 1: add data one point at a time, recalculating centroids with each new one.
        int sampleNumber = 0;
        while (dataSet.size() < TOTAL_DATA) {
            Data newData = new Data(SAMPLES[sampleNumber][0], SAMPLES[sampleNumber][1]);
            dataSet.add(newData);
            newData.cluster(nearestCluster(newData));
            recalculateCentroids();
            sampleNumber++;
        }

        // Phase 2: iterate until no point switches cluster.
        boolean isStillMoving = true;
        while (isStillMoving) {
            recalculateCentroids();

            isStillMoving = false;
            for (int i = 0; i < dataSet.size(); i++) {
                Data point = dataSet.get(i);
                int nearest = nearestCluster(point);
                // Compare BEFORE overwriting the assignment; otherwise a move
                // can never be detected and the loop always stops after one pass.
                if (point.cluster() != nearest) {
                    point.cluster(nearest);
                    isStillMoving = true;
                }
            }
        }
    }

    /**
     * Moves every centroid to the mean of the points currently assigned to it.
     * A cluster with no points keeps its previous centroid.
     */
    private void recalculateCentroids() {
        for (int i = 0; i < NUM_CLUSTERS; i++) {
            // Accumulate in double: an int accumulator would truncate each
            // coordinate on += and turn the mean into integer division.
            double totalX = 0.0;
            double totalY = 0.0;
            int totalInCluster = 0;
            for (int j = 0; j < dataSet.size(); j++) {
                Data point = dataSet.get(j);
                if (point.cluster() == i) {
                    totalX += point.X();
                    totalY += point.Y();
                    totalInCluster++;
                }
            }
            // A cluster can be empty (e.g. before any point has been assigned to it).
            if (totalInCluster > 0) {
                centroids.get(i).X(totalX / totalInCluster);
                centroids.get(i).Y(totalY / totalInCluster);
            }
        }
    }

    /**
     * Finds the cluster whose centroid is closest to the given point.
     * Ties go to the lowest cluster index.
     *
     * @param point the point to classify
     * @return index of the nearest centroid, in {@code [0, NUM_CLUSTERS)}
     */
    private int nearestCluster(Data point) {
        int nearest = 0;
        double minimum = Double.POSITIVE_INFINITY;
        for (int i = 0; i < NUM_CLUSTERS; i++) {
            double distance = dist(point, centroids.get(i));
            if (distance < minimum) {
                minimum = distance;
                nearest = i;
            }
        }
        return nearest;
    }

    /**
     * Calculates the Euclidean distance between a point and a centroid.
     *
     * @param d the data point
     * @param c the centroid
     * @return the straight-line distance between {@code d} and {@code c}
     */
    private static double dist(Data d, Centroid c) {
        return Math.sqrt(Math.pow((c.Y() - d.Y()), 2) + Math.pow((c.X() - d.X()), 2));
    }
}
复制代码
复制代码
View Code
package kmeans;

/**
 * Command-line driver: runs the k-means demo and prints the members of each
 * cluster followed by the final centroid positions.
 */
public class test {

    /**
     * Seeds the centroids, runs the clustering, and writes the results to
     * standard output.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        KMeans k = new KMeans();
        k.init();
        k.kMeanCluster();

        // Print the points belonging to each cluster.
        for (int i = 0; i < KMeans.NUM_CLUSTERS; i++) {
            System.out.println("Cluster " + i + " includes:");
            // Bound by the list's actual size so a short data set cannot
            // cause an out-of-range access.
            for (int j = 0; j < k.dataSet.size(); j++) {
                if (k.dataSet.get(j).cluster() == i) {
                    System.out.println(k.dataSet.get(j).X() + ", " + k.dataSet.get(j).Y());
                }
            }
            System.out.println();
        }

        System.out.println("Centroids finalized at:");
        for (int i = 0; i < KMeans.NUM_CLUSTERS; i++) {
            // Closing parenthesis added so each centroid prints as a balanced "(x, y)" pair.
            System.out.println("     (" + k.centroids.get(i).X() + ", " + k.centroids.get(i).Y() + ")");
        }
        System.out.print("\n");
    }

}
复制代码

ref:http://mnemstudio.org/clustering-k-means-example-1.htm (neat)

http://blog.jobbole.com/23157/

 

努力加载评论中...
点击右上角即可分享
微信分享提示