spark mllib k-means算法实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | package iie.udps.example.spark.mllib; import java.util.regex.Pattern; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.KMeans; import org.apache.spark.mllib.clustering.KMeansModel; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; /** * Example using MLLib KMeans from Java. * * spark-submit --class iie.udps.example.spark.mllib.JavaKMeans --master * yarn-cluster --num-executors 15 --driver-memory 512m --executor-memory 2g * --executor-cores 2 /home/xdf/test2.jar /user/xdf/Example.txt 10 2 */ public final class JavaKMeans { @SuppressWarnings ( "serial" ) private static class ParsePoint implements Function<String, Vector> { private static final Pattern SPACE = Pattern.compile( "," ); @Override public Vector call(String line) { String[] tok = SPACE.split(line); // 统一数据维度为3,此处没有考虑其他异常数据情况 if (tok.length < 3 ) { tok = SPACE.split(line + ",0" ); for ( int i = tok.length; i < 3 ; i++) { tok[i] = "0" ; } } if (tok.length > 3 ) { tok = SPACE.split( "0,0,0" ); } double [] point = new double [tok.length]; for ( int i = 0 ; i < tok.length; ++i) { point[i] = Double.parseDouble(tok[i]); } return Vectors.dense(point); } } public static void main(String[] args) { if (args.length < 3 ) { System.err .println( "Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]" ); System.exit( 1 ); } String inputFile = args[ 0 ]; // 要读取的文件 int k = Integer.parseInt(args[ 1 ]); // 聚类个数 int iterations = Integer.parseInt(args[ 2 ]); // 迭代次数 int runs = 1 ; // 运行算法次数 if (args.length >= 4 ) { runs = Integer.parseInt(args[ 3 ]); } SparkConf sparkConf = new SparkConf().setAppName( "JavaKMeans" ); // sparkConf.set("spark.default.parallelism", "4"); // sparkConf.set("spark.akka.frameSize", "1024"); System.setProperty( "dfs.client.block.write.replace-datanode-on-failure.enable" , "true" ); System.setProperty( "dfs.client.block.write.replace-datanode-on-failure.policy" , "never" ); // sparkConf.set( // "dfs.client.block.write.replace-datanode-on-failure.enable", // "true"); // sparkConf.set( // "dfs.client.block.write.replace-datanode-on-failure.policy", // "never"); JavaSparkContext sc = new JavaSparkContext(sparkConf); // 指定文件分片数 JavaRDD<String> lines = sc.textFile(inputFile, 2400 ); // ,1264 , 1872,2400 JavaRDD<Vector> points = lines.map( new ParsePoint()); KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL()); // System.out.println("Vector 98, 345, 90 belongs to clustering :" // + model.predict(Vectors.dense(98, 345, 90))); // System.out.println("Vector 748, 965, 202 belongs to clustering :" // + model.predict(Vectors.dense(748, 965, 202))); // System.out.println("Vector 310, 554, 218 belongs to clustering :" // + model.predict(Vectors.dense(310, 554, 218))); System.out.println( "Cluster centers:" ); for (Vector center : model.clusterCenters()) { System.out.println( " " + center); } double cost = model.computeCost(points.rdd()); System.out.println( "Cost: " + cost); sc.stop(); } } |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 分享 3 个 .NET 开源的文件压缩处理库,助力快速实现文件压缩解压功能!
· Ollama——大语言模型本地部署的极速利器
· DeepSeek如何颠覆传统软件测试?测试工程师会被淘汰吗?