【算法】Kmeans
package com.pachira.d; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; public class Kmeans { /** * Kmeans聚类算法 * 基本思想: * 以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。 * * 过程描述: * 输入:k, data[n], eps; * (1) 选择k个初始中心点,例如c[0]=data[0],…c[k-1]=data[k-1]; * (2) 对于data[0]….data[n], 分别与c[0]…c[k-1]比较,假定与c[i]差值最少,就标记为i; * (3) 对于所有标记为i点,重新计算c[i]={ 所有标记为i的data[j]之和}/标记为i的个数; * (4) 重复(2)(3),直到所有c[i]值的变化小于给定阈值。 * * 其他说明: * 1、Kmeans的变种,其距离计算不是欧基米德距离,有可能会出现问题; * 2、海量数据聚类,欧基米德距离要比余弦相似性好(Inderjit S.Dhillon James FAN 和 Yuqiang Guan论文) * * data[n]的每个元素往往是一个向量; * */ /** * 初始化聚类中心点 * @param k 中心点数 * @param data 带聚类的数据集 * @return 中心点集 */ public static double[] getPoints(int k, int[] data){ double[] points = new double[k]; for (int i = 0; i < k; i++) { points[i] = (double)data[i]; } return points; } /** * 计算元素和每个中心点的距离,将该元素归为最小距离的中心点中 * @param points 中心点集 * @param data 元素集 * @return 聚类结果 */ public static LinkedHashMap<Double, List<Integer>> culcate(double[] points, int[] data){ LinkedHashMap<Double, List<Integer>> map = new LinkedHashMap<Double, List<Integer>>(); for (int i = 0; i < data.length; i++) { //get one point to culcate the distance int d = data[i]; double minDistance = Double.MAX_VALUE; double key = -1; for(int j = 0; j < points.length; j++){ //欧基米德距离 double tmp = Math.sqrt(Math.pow((d - points[j]), 2)); if(tmp < minDistance){ minDistance = tmp; key = points[j]; } } // System.out.println(key); if(map.containsKey(key)){ List<Integer> cus = map.get(key); cus.add(d); }else{ List<Integer> cus = new ArrayList<Integer>(); cus.add(d); map.put(key, cus); } } return map; } /** * 重置中心点 * @param 聚类结果 * @return 重置后的中心点集 */ public static double[] resetPoint(HashMap<Double, List<Integer>> map){ double[] tmp = new double[map.keySet().size()]; int index = 0; for(double key: map.keySet()){ List<Integer> val = map.get(key); double total = 0; for (int i = 0; i < val.size(); i++) { total += val.get(i); } if(val.size() == 0){ tmp[index++] = key; }else{ key = total / val.size(); tmp[index++] = key; } } return tmp; } /** * Kmeans * @param data 待聚类元素集合 * @param k 类别数目(中心点数) * @param eps 收敛阈值 * @return 聚类结果 */ public static LinkedHashMap<Double, List<Integer>> kmeans(int[] data, int k, double eps){ double[] points = getPoints(k, data); LinkedHashMap<Double, List<Integer>> tmp = null; while(true){ tmp = culcate(points, data); show(tmp); double[] tpoints = resetPoint(tmp); boolean flag = true; for (int i = 0; i < tpoints.length; i++) { if(Math.abs(points[i] - tpoints[i]) > eps){ flag = false; break; } } if(flag)break; points = tpoints; } return null; } /** * 显示聚类结果 * @param map */ public static void show(LinkedHashMap<Double, List<Integer>> map){ for (double key: map.keySet()) { System.out.println(String.format("%.2f", key) + "\t" + map.get(key)); } System.out.println("================================="); } public static void main(String[] args) { int k = 10; double eps = 0.001; int[] data = {45, 26, 45, 65, 49, 27, 44, 26, 40, 63, 35, 63, 47, 24, 65, 62, 38, 8, 43, 65, 34, 36, 80, 34, 62, 60, 54, 66, 86, 47, 73, 15, 40, 7, 12, 35, 88, 5, 9, 20, 94, 28, 70, 78, 87, 78, 43, 80, 25, 88, 46, 21, 52, 49, 36, 64, 52, 59, 24, 56, 54, 10, 81, 78, 66, 28, 53, 48, 2, 89, 44, 79, 16, 55, 27, 6, 0, 46, 76, 87, 30, 90, 40, 51, 98, 97, 55, 72, 32, 79, 61, 39, 74, 58, 55, 58, 32, 4, 76, 19}; kmeans(data, k, eps); } }