flink KMeans算法实现

更正:之前发的有两个错误。

1、K均值聚类算法

百度解释:k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是随机选取K个对象作为初始的聚类中心,
然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。
聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。
这个过程将不断重复直到满足某个终止条件。
终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

2、二维坐标点POJO

public class Point {
    public double x, y;

    public Point() {}

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    public Point add(Point other) {
        x += other.x;
        y += other.y;
        return this;
    }

    //取均值使用
    public Point div(long val) {
        x /= val;
        y /= val;
        return this;
    }

    //欧几里得距离
    public double euclideanDistance(Point other) {
        return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
    }

    public void clear() {
        x = y = 0.0;
    }

    @Override
    public String toString() {
        return x + " " + y;
    }
}

二维聚类中心POJO

public class Centroid extends Point{
    public int id;

    public Centroid() {}

    public Centroid(int id, double x, double y) {
        super(x, y);
        this.id = id;
    }

    public Centroid(int id, Point p) {
        super(p.x, p.y);
        this.id = id;
    }

    @Override
    public String toString() {
        return id + " " + super.toString();
    }
}

3、缺省的数据准备

public class KMeansData {
    // We have the data as object arrays so that we can also generate Scala Data Sources from it.
    public static final Object[][] CENTROIDS = new Object[][] {
            new Object[] {1, -31.85, -44.77},
            new Object[]{2, 35.16, 17.46},
            new Object[]{3, -5.16, 21.93},
            new Object[]{4, -24.06, 6.81}
    };

    public static final Object[][] POINTS = new Object[][] {
            new Object[] {-14.22, -48.01},
            new Object[] {-22.78, 37.10},
            new Object[] {56.18, -42.99},
            new Object[] {35.04, 50.29},
            new Object[] {-9.53, -46.26},
            new Object[] {-34.35, 48.25},
            new Object[] {55.82, -57.49},
            new Object[] {21.03, 54.64},
            new Object[] {-13.63, -42.26},
            new Object[] {-36.57, 32.63},
            new Object[] {50.65, -52.40},
            new Object[] {24.48, 34.04},
            new Object[] {-2.69, -36.02},
            new Object[] {-38.80, 36.58},
            new Object[] {24.00, -53.74},
            new Object[] {32.41, 24.96},
            new Object[] {-4.32, -56.92},
            new Object[] {-22.68, 29.42},
            new Object[] {59.02, -39.56},
            new Object[] {24.47, 45.07},
            new Object[] {5.23, -41.20},
            new Object[] {-23.00, 38.15},
            new Object[] {44.55, -51.50},
            new Object[] {14.62, 59.06},
            new Object[] {7.41, -56.05},
            new Object[] {-26.63, 28.97},
            new Object[] {47.37, -44.72},
            new Object[] {29.07, 51.06},
            new Object[] {0.59, -31.89},
            new Object[] {-39.09, 20.78},
            new Object[] {42.97, -48.98},
            new Object[] {34.36, 49.08},
            new Object[] {-21.91, -49.01},
            new Object[] {-46.68, 46.04},
            new Object[] {48.52, -43.67},
            new Object[] {30.05, 49.25},
            new Object[] {4.03, -43.56},
            new Object[] {-37.85, 41.72},
            new Object[] {38.24, -48.32},
            new Object[] {20.83, 57.85}
    };

    public static DataSet<Centroid> getDefaultCentroidDataSet(ExecutionEnvironment env) {
        List<Centroid> centroidList = new LinkedList<Centroid>();
        for (Object[] centroid : CENTROIDS) {
            centroidList.add(
                    new Centroid((Integer) centroid[0], (Double) centroid[1], (Double) centroid[2]));
        }
        return env.fromCollection(centroidList);
    }

    public static DataSet<Point> getDefaultPointDataSet(ExecutionEnvironment env) {
        List<Point> pointList = new LinkedList<Point>();
        for (Object[] point : POINTS) {
            pointList.add(new Point((Double) point[0], (Double) point[1]));
        }
        return env.fromCollection(pointList);
    }
}

4、KMeans聚类算法实现

/**
 * @Author: xu.dm
 * @Date: 2019/7/9 16:31
 * @Version: 1.0
 * @Description:
 * K-Means是一种迭代聚类算法,其工作原理如下:
 * K-Means给出了一组要聚类的数据点和一组初始的K聚类中心。
 * 在每次迭代中,算法计算每个数据点到每个聚类中心的距离。每个点都分配给最靠近它的集群中心。
 * 随后,每个聚类中心移动到已分配给它的所有点的中心(平均值)。移动的聚类中心被送入下一次迭代。
 * 该算法在固定次数的迭代之后终止(本例中)或者如果聚类中心在迭代中没有(显着地)移动。
 * 这是K-Means聚类算法的维基百科条目。
 * <a href="http://en.wikipedia.org/wiki/K-means_clustering">
 *
 * 此实现适用于二维数据点。
 * 它计算到集群中心的数据点分配,即每个数据点都使用它所属的最终集群(中心)的id进行注释。
 *
 * 输入文件是纯文本文件,必须格式如下:
 *
 * 数据点表示为由空白字符分隔的两个双精度值。数据点由换行符分隔。
 * 例如,"1.2 2.3\n5.3 7.2\n"给出两个数据点(x = 1.2,y = 2.3)和(x = 5.3,y = 7.2)。
 * 聚类中心由整数id和点值表示。
 * 例如,"1 6.2 3.2\n2 2.9 5.7\n"给出两个中心(id = 1,x = 6.2,y = 3.2)和(id = 2,x = 2.9,y = 5.7)。
 * 用法:KMeans --points <path> --centroids <path> --output <path> --iterations <n>
 * 如果未提供参数,则使用{@link KMeansData}中的默认数据和10次迭代运行程序。
 **/
public class KMeans {
    public static void main(String args[]) throws Exception{
        final ParameterTool params = ParameterTool.fromArgs(args);

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        env.getConfig().setGlobalJobParameters(params);

        DataSet<Point> points =getPointDataSet(params,env);
        DataSet<Centroid> centroids = getCentroidDataSet(params, env);

        IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations",10));

        DataSet<Centroid> newCentroid = points
                //计算每个点距离最近的聚类中心
                .flatMap(new SelectNearestCenter()).withBroadcastSet(loop,"centroids")
                //计算每个点到最近聚类中心的计数
                .map(new CountAppender())
                .groupBy(0).reduce(new CentroidAccumulator())
                //计算新的聚类中心
                .map(new CentroidAverager());

        //闭合迭代 loop->points->newCentroid(loop)
        DataSet<Centroid> finalCentroid = loop.closeWith(newCentroid);

        //分配所有点到新的聚类中心
        DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                .flatMap(new SelectNearestCenter()).withBroadcastSet(finalCentroid,"centroids");

        // emit result
        if (params.has("output")) {
            clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");

            // since file sinks are lazy, we trigger the execution explicitly
            env.execute("KMeans Example");
        } else {
            System.out.println("Printing result to stdout. Use --output to specify output path.");

            clusteredPoints.print();
        }
    }

    private static DataSet<Point> getPointDataSet(ParameterTool params,ExecutionEnvironment env){
        DataSet<Point> points;
        if(params.has("points")){
            points = env.readCsvFile(params.get("points")).fieldDelimiter(" ")
                    .pojoType(Point.class,"x","y");
        }else{
            System.out.println("Executing K-Means example with default point data set.");
            System.out.println("Use --points to specify file input.");
            points = KMeansData.getDefaultPointDataSet(env);
        }
        return points;
    }

    private static DataSet<Centroid> getCentroidDataSet(ParameterTool params,ExecutionEnvironment env){
        DataSet<Centroid> centroids;
        if(params.has("centroids")){
            centroids = env.readCsvFile(params.get("centroids")).fieldDelimiter(" ")
                    .pojoType(Centroid.class,"id","x","y");
        }else{
            System.out.println("Executing K-Means example with default centroid data set.");
            System.out.println("Use --centroids to specify file input.");
            centroids = KMeansData.getDefaultCentroidDataSet(env);
        }
        return centroids;
    }

    /** Determines the closest cluster center for a data point.
     * 找到最近的聚类中心
     * */
    @FunctionAnnotation.ForwardedFields("*->1")
    public static final class SelectNearestCenter extends RichFlatMapFunction<Point, Tuple2<Integer,Point>>{
        private Collection<Centroid> centroids;

        /** Reads the centroid values from a broadcast variable into a collection.
         * 从广播变量里读取聚类中心点数据到集合中
         * */
        @Override
        public void open(Configuration parameters) throws Exception {
            this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
        }

        @Override
        public void flatMap(Point point, Collector<Tuple2<Integer, Point>> out) throws Exception {
            double minDistance = Double.MAX_VALUE;
            int closestCentroidId = -1;

            //检查所有聚类中心
            for(Centroid centroid:centroids){
                //计算点到聚类中心的距离
                double distance = point.euclideanDistance(centroid);

                //更新最小距离
                if(distance<minDistance){
                    minDistance = distance;
                    closestCentroidId = centroid.id;
                }
            }
            out.collect(new Tuple2<>(closestCentroidId,point));
        }
    }

    /**
     * 增加一个计数变量
     */
    @FunctionAnnotation.ForwardedFields("f0;f1")
    public static final class CountAppender implements MapFunction<Tuple2<Integer,Point>, Tuple3<Integer,Point,Long>>{
        @Override
        public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> value) throws Exception {
            return new Tuple3<>(value.f0,value.f1,1L);
        }
    }

    /**
     * 合计坐标点和计数,下一步重新平均
     */
    @FunctionAnnotation.ForwardedFields("0")
    public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer,Point,Long>>{
        @Override
        public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> value1, Tuple3<Integer, Point, Long> value2) throws Exception {
            return Tuple3.of(value1.f0,value1.f1.add(value2.f1),value1.f2+value2.f2);
        }
    }

    /**
     *重新计算聚类中心
     */
    @FunctionAnnotation.ForwardedFields("0->id")
    public static final class CentroidAverager implements MapFunction<Tuple3<Integer,Point,Long>,Centroid>{
        @Override
        public Centroid map(Tuple3<Integer,Point,Long> value) throws Exception {
            return new Centroid(value.f0,value.f1.div(value.f2));
        }
    }


}

 

 

 

 
posted @ 2019-07-09 21:36  我是属车的  阅读(1818)  评论(6编辑  收藏  举报