【算法】Kmeans

package com.pachira.d;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;

public class Kmeans {
    /**
     *  Kmeans聚类算法
     *  基本思想:
     *  以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。
     *  
     *  过程描述:
     *  输入:k, data[n], eps;
     * (1) 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1];
     * (2) 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i;
     * (3) 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数;
     * (4) 重复(2)(3),直到所有c[i]值的变化小于给定阈值。
     *  
     *  其他说明:
     *  1、Kmeans的变种,其距离计算不是欧基米德距离,有可能会出现问题;
     *  2、海量数据聚类,欧基米德距离要比余弦相似性好(Inderjit S.Dhillon James FAN 和 Yuqiang Guan论文)
     *  
     *  data[n]的每个元素往往是一个向量;
     *  
     */
    /**
     * 初始化聚类中心点 
     * @param k 中心点数
     * @param data 带聚类的数据集
     * @return 中心点集
     */
    public static double[] getPoints(int k, int[] data){
        double[] points = new double[k];
        for (int i = 0; i < k; i++) {
            points[i] = (double)data[i];
        }
        return points;
    }
    /**
     * 计算元素和每个中心点的距离,将该元素归为最小距离的中心点中
     * @param points 中心点集
     * @param data 元素集
     * @return 聚类结果
     */
    public static LinkedHashMap<Double, List<Integer>> culcate(double[] points, int[] data){
        LinkedHashMap<Double, List<Integer>> map = new LinkedHashMap<Double, List<Integer>>();
        for (int i = 0; i < data.length; i++) {
            //get one point to culcate the distance
            int d = data[i];
            double minDistance = Double.MAX_VALUE;
            double key = -1;
            for(int j = 0; j < points.length; j++){
                //欧基米德距离
                double tmp = Math.sqrt(Math.pow((d - points[j]), 2));
                if(tmp < minDistance){
                    minDistance = tmp;
                    key = points[j];
                }
            }
//            System.out.println(key);
            if(map.containsKey(key)){
                List<Integer> cus = map.get(key);
                cus.add(d);
            }else{
                List<Integer> cus = new ArrayList<Integer>();
                cus.add(d);
                map.put(key, cus);
            }
        }
        return map;
    }
    /**
     * 重置中心点
     * @param 聚类结果
     * @return 重置后的中心点集
     */
    public static double[] resetPoint(HashMap<Double, List<Integer>> map){
        double[] tmp = new double[map.keySet().size()];
        int index = 0;
        for(double key: map.keySet()){
            List<Integer> val = map.get(key);
            double total = 0;
            for (int i = 0; i < val.size(); i++) {
                total += val.get(i);
            }
            if(val.size() == 0){
                tmp[index++] = key;
            }else{
                key = total / val.size();
                tmp[index++] = key;
            }
        }
        return tmp;
    }
    /**
     * Kmeans
     * @param data 待聚类元素集合
     * @param k 类别数目(中心点数)
     * @param eps 收敛阈值
     * @return 聚类结果
     */
    public static LinkedHashMap<Double, List<Integer>> kmeans(int[] data, int k, double eps){
        double[] points = getPoints(k, data);
        LinkedHashMap<Double, List<Integer>> tmp = null;
        while(true){
            tmp = culcate(points, data);
            show(tmp);
            double[] tpoints = resetPoint(tmp);
            boolean flag = true;
            for (int i = 0; i < tpoints.length; i++) {
                if(Math.abs(points[i] - tpoints[i]) > eps){
                    flag = false;
                    break;
                }
            }
            if(flag)break;
            points = tpoints;
        }
        return null;
    }
    /**
     * 显示聚类结果
     * @param map
     */
    public static void show(LinkedHashMap<Double, List<Integer>> map){
        for (double key: map.keySet()) {
            System.out.println(String.format("%.2f", key) + "\t" + map.get(key));
        }
        System.out.println("=================================");
    }
    public static void main(String[] args) {
        int k = 10;
        double eps = 0.001;
        int[] data = {45, 26, 45, 65, 49, 27, 44, 26, 40, 63, 35, 63, 47, 24, 65, 62, 38, 8, 43, 65, 34, 36, 80, 34, 62, 60, 54, 66, 86, 47, 73, 15, 40, 7, 12, 35, 88, 5, 9, 20, 94, 28, 70, 78, 87, 78, 43, 80, 25, 88, 46, 21, 52, 49, 36, 64, 52, 59, 24, 56, 54, 10, 81, 78, 66, 28, 53, 48, 2, 89, 44, 79, 16, 55, 27, 6, 0, 46, 76, 87, 30, 90, 40, 51, 98, 97, 55, 72, 32, 79, 61, 39, 74, 58, 55, 58, 32, 4, 76, 19};
        kmeans(data, k, eps);
    }
}

 

posted on 2014-12-14 00:24  有个姑娘叫小芳  阅读(336)  评论(0编辑  收藏  举报