Spark MLlib k-means algorithm implementation

package iie.udps.example.spark.mllib;
 
import java.util.regex.Pattern;
 
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
 
/**
 * Example using MLlib KMeans from Java.
 *
 * spark-submit --class iie.udps.example.spark.mllib.JavaKMeans --master
 * yarn-cluster --num-executors 15 --driver-memory 512m --executor-memory 2g
 * --executor-cores 2 /home/xdf/test2.jar /user/xdf/Example.txt 10 2
 */
public final class JavaKMeans {
 
    @SuppressWarnings("serial")
    private static class ParsePoint implements Function<String, Vector> {
        private static final Pattern COMMA = Pattern.compile(",");
 
        @Override
        public Vector call(String line) {
            String[] tok = COMMA.split(line);
            // Normalize every record to exactly three dimensions; other kinds
            // of malformed input are not handled here.
            if (tok.length < 3) {
                // Pad short records with zeros.
                String[] padded = { "0", "0", "0" };
                System.arraycopy(tok, 0, padded, 0, tok.length);
                tok = padded;
            } else if (tok.length > 3) {
                // Collapse over-long records to the origin.
                tok = COMMA.split("0,0,0");
            }
            double[] point = new double[tok.length];
            for (int i = 0; i < tok.length; ++i) {
                point[i] = Double.parseDouble(tok[i]);
            }
            return Vectors.dense(point);
        }
 
    }
 
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err
                    .println("Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]");
            System.exit(1);
        }
        String inputFile = args[0]; // input file to read
        int k = Integer.parseInt(args[1]); // number of clusters
        int iterations = Integer.parseInt(args[2]); // maximum number of iterations
        int runs = 1; // number of times to run the algorithm
 
        if (args.length >= 4) {
            runs = Integer.parseInt(args[3]);
        }
        SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans");
        // sparkConf.set("spark.default.parallelism", "4");
        // sparkConf.set("spark.akka.frameSize", "1024");
        // HDFS client settings for handling datanode replacement on write failure.
        System.setProperty(
                "dfs.client.block.write.replace-datanode-on-failure.enable",
                "true");
        System.setProperty(
                "dfs.client.block.write.replace-datanode-on-failure.policy",
                "never");
        // sparkConf.set(
        // "dfs.client.block.write.replace-datanode-on-failure.enable",
        // "true");
        // sparkConf.set(
        // "dfs.client.block.write.replace-datanode-on-failure.policy",
        // "never");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        // Specify the number of input partitions for the text file.
        JavaRDD<String> lines = sc.textFile(inputFile, 2400); // also tried 1264, 1872
        JavaRDD<Vector> points = lines.map(new ParsePoint());
 
        KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs,
                KMeans.K_MEANS_PARALLEL());
 
//      System.out.println("Vector (98, 345, 90) belongs to cluster "
//              + model.predict(Vectors.dense(98, 345, 90)));
//      System.out.println("Vector (748, 965, 202) belongs to cluster "
//              + model.predict(Vectors.dense(748, 965, 202)));
//      System.out.println("Vector (310, 554, 218) belongs to cluster "
//              + model.predict(Vectors.dense(310, 554, 218)));
 
        System.out.println("Cluster centers:");
        for (Vector center : model.clusterCenters()) {
            System.out.println(" " + center);
        }
        double cost = model.computeCost(points.rdd());
        System.out.println("Cost: " + cost);
 
        sc.stop();
    }
}
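
The parser above normalizes every record to three comma-separated numeric values, so the input file is expected to hold one point per line. The original post does not show the data, but a hypothetical /user/xdf/Example.txt accepted by ParsePoint could look like this (the two-value line would be zero-padded to 310,554,0):

    98,345,90
    748,965,202
    310,554
    420,117,66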
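
As a follow-up sketch, not part of the original post: besides printing the cluster centers, the trained model can label each input point. Assuming this snippet is appended inside main() after training, reusing the points RDD and model variable from the listing above (Spark 1.x MLlib API):

        // Sketch only: assign each parsed point to its nearest cluster and
        // count how many points fall into each cluster.
        final KMeansModel trainedModel = model;
        JavaRDD<Integer> clusterIndices = points
                .map(new Function<Vector, Integer>() {
                    @Override
                    public Integer call(Vector v) {
                        return trainedModel.predict(v);
                    }
                });
        System.out.println("Cluster sizes: " + clusterIndices.countByValue());

countByValue() returns a java.util.Map from cluster index to point count, which gives a quick sanity check of how balanced the clustering is.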

  
