Flink KMeans Algorithm Implementation
Correction: the previously posted version contained two errors.
1. The k-means clustering algorithm
Baidu's definition: the k-means clustering algorithm is an iterative clustering-analysis algorithm. It starts by randomly selecting K objects as the initial cluster centers,
then computes the distance between every object and each of the seed cluster centers, and assigns each object to the cluster center nearest to it.
A cluster center together with the objects assigned to it represents one cluster. Each time a sample is assigned, the cluster's center is recomputed from the objects currently in that cluster.
This process repeats until a termination condition is met.
The termination condition can be that no objects (or only a minimal number) are reassigned to a different cluster, that no cluster centers (or only a minimal number) change any further, or that the sum of squared errors reaches a local minimum.
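To make the iteration above concrete before jumping into Flink, here is a minimal, framework-free sketch of a single assignment-and-update step, written against the Point and Centroid POJOs introduced in the next sections. The kMeansStep helper and its wrapper class are illustrative only and are not part of the original example.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class KMeansStepSketch {

    // One assignment + update step of k-means (illustrative helper, not from the original post).
    static List<Centroid> kMeansStep(List<Point> points, List<Centroid> centroids) {
        Map<Integer, Point> sums = new HashMap<>();
        Map<Integer, Long> counts = new HashMap<>();
        for (Point p : points) {
            // assign the point to its nearest centroid
            Centroid nearest = centroids.get(0);
            for (Centroid c : centroids) {
                if (p.euclideanDistance(c) < p.euclideanDistance(nearest)) {
                    nearest = c;
                }
            }
            sums.computeIfAbsent(nearest.id, k -> new Point(0.0, 0.0)).add(p);
            counts.merge(nearest.id, 1L, Long::sum);
        }
        // move each centroid to the mean of the points assigned to it
        List<Centroid> updated = new ArrayList<>();
        for (Map.Entry<Integer, Point> e : sums.entrySet()) {
            updated.add(new Centroid(e.getKey(), e.getValue().div(counts.get(e.getKey()))));
        }
        return updated;
    }
}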
2. 2D coordinate point POJO
public class Point {

    public double x, y;

    public Point() {}

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    public Point add(Point other) {
        x += other.x;
        y += other.y;
        return this;
    }

    // used when taking the mean
    public Point div(long val) {
        x /= val;
        y /= val;
        return this;
    }

    // Euclidean distance
    public double euclideanDistance(Point other) {
        return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
    }

    public void clear() {
        x = y = 0.0;
    }

    @Override
    public String toString() {
        return x + " " + y;
    }
}
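Note that add and div mutate the receiving Point and return this, which is exactly how the reducer and averager in section 4 chain them. A quick illustrative check of these semantics (the values are made up, not from the original post):

Point sum = new Point(0.0, 0.0);
sum.add(new Point(1.0, 2.0));
sum.add(new Point(3.0, 6.0));
Point mean = sum.div(2L);                                  // (2.0, 4.0); div mutates 'sum' in place
double d = mean.euclideanDistance(new Point(2.0, 0.0));    // 4.0
System.out.println(mean + " distance=" + d);               // prints "2.0 4.0 distance=4.0"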
2D cluster-center (centroid) POJO
public class Centroid extends Point {

    public int id;

    public Centroid() {}

    public Centroid(int id, double x, double y) {
        super(x, y);
        this.id = id;
    }

    public Centroid(int id, Point p) {
        super(p.x, p.y);
        this.id = id;
    }

    @Override
    public String toString() {
        return id + " " + super.toString();
    }
}
3. Default data preparation
import java.util.LinkedList;
import java.util.List;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;

public class KMeansData {

    // We have the data as object arrays so that we can also generate Scala Data Sources from it.
    public static final Object[][] CENTROIDS = new Object[][] {
        new Object[] {1, -31.85, -44.77},
        new Object[] {2, 35.16, 17.46},
        new Object[] {3, -5.16, 21.93},
        new Object[] {4, -24.06, 6.81}
    };

    public static final Object[][] POINTS = new Object[][] {
        new Object[] {-14.22, -48.01}, new Object[] {-22.78, 37.10}, new Object[] {56.18, -42.99}, new Object[] {35.04, 50.29},
        new Object[] {-9.53, -46.26}, new Object[] {-34.35, 48.25}, new Object[] {55.82, -57.49}, new Object[] {21.03, 54.64},
        new Object[] {-13.63, -42.26}, new Object[] {-36.57, 32.63}, new Object[] {50.65, -52.40}, new Object[] {24.48, 34.04},
        new Object[] {-2.69, -36.02}, new Object[] {-38.80, 36.58}, new Object[] {24.00, -53.74}, new Object[] {32.41, 24.96},
        new Object[] {-4.32, -56.92}, new Object[] {-22.68, 29.42}, new Object[] {59.02, -39.56}, new Object[] {24.47, 45.07},
        new Object[] {5.23, -41.20}, new Object[] {-23.00, 38.15}, new Object[] {44.55, -51.50}, new Object[] {14.62, 59.06},
        new Object[] {7.41, -56.05}, new Object[] {-26.63, 28.97}, new Object[] {47.37, -44.72}, new Object[] {29.07, 51.06},
        new Object[] {0.59, -31.89}, new Object[] {-39.09, 20.78}, new Object[] {42.97, -48.98}, new Object[] {34.36, 49.08},
        new Object[] {-21.91, -49.01}, new Object[] {-46.68, 46.04}, new Object[] {48.52, -43.67}, new Object[] {30.05, 49.25},
        new Object[] {4.03, -43.56}, new Object[] {-37.85, 41.72}, new Object[] {38.24, -48.32}, new Object[] {20.83, 57.85}
    };

    public static DataSet<Centroid> getDefaultCentroidDataSet(ExecutionEnvironment env) {
        List<Centroid> centroidList = new LinkedList<Centroid>();
        for (Object[] centroid : CENTROIDS) {
            centroidList.add(
                new Centroid((Integer) centroid[0], (Double) centroid[1], (Double) centroid[2]));
        }
        return env.fromCollection(centroidList);
    }

    public static DataSet<Point> getDefaultPointDataSet(ExecutionEnvironment env) {
        List<Point> pointList = new LinkedList<Point>();
        for (Object[] point : POINTS) {
            pointList.add(new Point((Double) point[0], (Double) point[1]));
        }
        return env.fromCollection(pointList);
    }
}
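If you want to sanity-check this default data locally before wiring up the full job, a minimal sketch is shown below (illustrative only; print() on a DataSet triggers execution, so no explicit env.execute() is needed here):

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
KMeansData.getDefaultCentroidDataSet(env).print();   // the 4 initial centroids
KMeansData.getDefaultPointDataSet(env).print();      // the 40 sample points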
4. Implementation of the KMeans clustering algorithm
import java.util.Collection;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

/**
 * @Author: xu.dm
 * @Date: 2019/7/9 16:31
 * @Version: 1.0
 * @Description:
 * K-Means is an iterative clustering algorithm that works as follows:
 * K-Means is given a set of data points to be clustered and an initial set of K cluster centers.
 * In each iteration, the algorithm computes the distance of each data point to each cluster center.
 * Each point is assigned to the cluster center closest to it.
 * Subsequently, each cluster center is moved to the center (mean) of all points assigned to it.
 * The moved cluster centers are fed into the next iteration.
 * The algorithm terminates after a fixed number of iterations (as in this example)
 * or if the cluster centers do not (significantly) move between iterations.
 * This is the Wikipedia entry for the K-Means clustering algorithm:
 * <a href="http://en.wikipedia.org/wiki/K-means_clustering">http://en.wikipedia.org/wiki/K-means_clustering</a>
 *
 * This implementation works on two-dimensional data points.
 * It computes an assignment of data points to cluster centers, i.e. each data point is
 * annotated with the id of the final cluster (center) it belongs to.
 *
 * Input files are plain text files and must have the following format:
 *
 * Data points are represented as two double values separated by a blank character.
 * Data points are separated by newline characters.
 * For example, "1.2 2.3\n5.3 7.2\n" gives two data points (x=1.2, y=2.3) and (x=5.3, y=7.2).
 * Cluster centers are represented by an integer id and a point value.
 * For example, "1 6.2 3.2\n2 2.9 5.7\n" gives two centers (id=1, x=6.2, y=3.2) and (id=2, x=2.9, y=5.7).
 *
 * Usage: KMeans --points <path> --centroids <path> --output <path> --iterations <n>
 * If no parameters are provided, the program is run with the default data from {@link KMeansData} and 10 iterations.
 **/
public class KMeans {

    public static void main(String[] args) throws Exception {
        final ParameterTool params = ParameterTool.fromArgs(args);
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.getConfig().setGlobalJobParameters(params);

        DataSet<Point> points = getPointDataSet(params, env);
        DataSet<Centroid> centroids = getCentroidDataSet(params, env);

        IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));

        DataSet<Centroid> newCentroid = points
                // compute the closest cluster center for each point
                .flatMap(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
                // append a count to each point, then sum points and counts per cluster center id
                .map(new CountAppender())
                .groupBy(0).reduce(new CentroidAccumulator())
                // compute the new cluster centers
                .map(new CentroidAverager());

        // close the iteration: loop -> points -> newCentroid (fed back into loop)
        DataSet<Centroid> finalCentroid = loop.closeWith(newCentroid);

        // assign all points to the final cluster centers
        DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                .flatMap(new SelectNearestCenter()).withBroadcastSet(finalCentroid, "centroids");

        // emit result
        if (params.has("output")) {
            clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
            // since file sinks are lazy, we trigger the execution explicitly
            env.execute("KMeans Example");
        } else {
            System.out.println("Printing result to stdout. Use --output to specify output path.");
            clusteredPoints.print();
        }
    }

    private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
        DataSet<Point> points;
        if (params.has("points")) {
            points = env.readCsvFile(params.get("points")).fieldDelimiter(" ")
                    .pojoType(Point.class, "x", "y");
        } else {
            System.out.println("Executing K-Means example with default point data set.");
            System.out.println("Use --points to specify file input.");
            points = KMeansData.getDefaultPointDataSet(env);
        }
        return points;
    }

    private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
        DataSet<Centroid> centroids;
        if (params.has("centroids")) {
            centroids = env.readCsvFile(params.get("centroids")).fieldDelimiter(" ")
                    .pojoType(Centroid.class, "id", "x", "y");
        } else {
            System.out.println("Executing K-Means example with default centroid data set.");
            System.out.println("Use --centroids to specify file input.");
            centroids = KMeansData.getDefaultCentroidDataSet(env);
        }
        return centroids;
    }

    /** Determines the closest cluster center for a data point. */
    @FunctionAnnotation.ForwardedFields("*->1")
    public static final class SelectNearestCenter extends RichFlatMapFunction<Point, Tuple2<Integer, Point>> {
        private Collection<Centroid> centroids;

        /** Reads the centroid values from a broadcast variable into a collection. */
        @Override
        public void open(Configuration parameters) throws Exception {
            this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
        }

        @Override
        public void flatMap(Point point, Collector<Tuple2<Integer, Point>> out) throws Exception {
            double minDistance = Double.MAX_VALUE;
            int closestCentroidId = -1;

            // check all cluster centers
            for (Centroid centroid : centroids) {
                // compute the distance from the point to the cluster center
                double distance = point.euclideanDistance(centroid);
                // update the nearest cluster center if necessary
                if (distance < minDistance) {
                    minDistance = distance;
                    closestCentroidId = centroid.id;
                }
            }
            out.collect(new Tuple2<>(closestCentroidId, point));
        }
    }

    /** Appends a count variable (1) to each point. */
    @FunctionAnnotation.ForwardedFields("f0;f1")
    public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {
        @Override
        public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> value) throws Exception {
            return new Tuple3<>(value.f0, value.f1, 1L);
        }
    }

    /** Sums coordinates and counts per cluster center; the averaging happens in the next step. */
    @FunctionAnnotation.ForwardedFields("0")
    public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {
        @Override
        public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> value1, Tuple3<Integer, Point, Long> value2) throws Exception {
            return Tuple3.of(value1.f0, value1.f1.add(value2.f1), value1.f2 + value2.f2);
        }
    }

    /** Recomputes the cluster center from the coordinate sum and the point count. */
    @FunctionAnnotation.ForwardedFields("0->id")
    public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {
        @Override
        public Centroid map(Tuple3<Integer, Point, Long> value) throws Exception {
            return new Centroid(value.f0, value.f1.div(value.f2));
        }
    }
}
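The class comment above also mentions the alternative stopping rule: terminate when the cluster centers no longer move (significantly). The fixed-iteration loop in this example can be adapted to that with the two-argument closeWith, whose second DataSet acts as a termination criterion; the iteration stops as soon as that DataSet becomes empty. A minimal sketch under that assumption (the 1e-3 threshold and the movedCentroids name are illustrative, not part of the original example; it additionally needs org.apache.flink.api.common.functions.FlatJoinFunction):

// inside main(), replacing loop.closeWith(newCentroid):
DataSet<Centroid> movedCentroids = loop
        .join(newCentroid).where("id").equalTo("id")
        .with(new FlatJoinFunction<Centroid, Centroid, Centroid>() {
            @Override
            public void join(Centroid oldCentroid, Centroid newCentroid, Collector<Centroid> out) {
                // keep only centroids that still moved more than the (assumed) threshold
                if (oldCentroid.euclideanDistance(newCentroid) > 1e-3) {
                    out.collect(newCentroid);
                }
            }
        });
// the loop now also stops early once no centroid moves noticeably (movedCentroids becomes empty)
DataSet<Centroid> finalCentroid = loop.closeWith(newCentroid, movedCentroids);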