Spark MLlib k-means implementation (Java)
package iie.udps.example.spark.mllib;

import java.util.regex.Pattern;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

/**
 * Example using MLlib KMeans from Java.
 *
 * spark-submit --class iie.udps.example.spark.mllib.JavaKMeans --master
 * yarn-cluster --num-executors 15 --driver-memory 512m --executor-memory 2g
 * --executor-cores 2 /home/xdf/test2.jar /user/xdf/Example.txt 10 2
 */
public final class JavaKMeans {

    @SuppressWarnings("serial")
    private static class ParsePoint implements Function<String, Vector> {
        private static final Pattern COMMA = Pattern.compile(",");

        @Override
        public Vector call(String line) {
            String[] tok = COMMA.split(line);
            // Normalize every record to 3 dimensions; other malformed input is
            // not handled beyond simple padding/reset here.
            if (tok.length < 3) {
                // Pad short rows with zeros up to 3 fields.
                String[] padded = { "0", "0", "0" };
                System.arraycopy(tok, 0, padded, 0, tok.length);
                tok = padded;
            } else if (tok.length > 3) {
                // Replace overlong rows with the origin point.
                tok = COMMA.split("0,0,0");
            }
            double[] point = new double[tok.length];
            for (int i = 0; i < tok.length; ++i) {
                point[i] = Double.parseDouble(tok[i]);
            }
            return Vectors.dense(point);
        }
    }

    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]");
            System.exit(1);
        }
        String inputFile = args[0];                 // file to read
        int k = Integer.parseInt(args[1]);          // number of clusters
        int iterations = Integer.parseInt(args[2]); // maximum iterations
        int runs = 1;                               // number of runs of the algorithm
        if (args.length >= 4) {
            runs = Integer.parseInt(args[3]);
        }

        SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans");
        // sparkConf.set("spark.default.parallelism", "4");
        // sparkConf.set("spark.akka.frameSize", "1024");
        System.setProperty(
                "dfs.client.block.write.replace-datanode-on-failure.enable",
                "true");
        System.setProperty(
                "dfs.client.block.write.replace-datanode-on-failure.policy",
                "never");
        // sparkConf.set(
        //         "dfs.client.block.write.replace-datanode-on-failure.enable",
        //         "true");
        // sparkConf.set(
        //         "dfs.client.block.write.replace-datanode-on-failure.policy",
        //         "never");

        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        // Number of partitions (splits) to read the input file into.
        JavaRDD<String> lines = sc.textFile(inputFile, 2400); // also tried 1264, 1872, 2400
        JavaRDD<Vector> points = lines.map(new ParsePoint());

        KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs,
                KMeans.K_MEANS_PARALLEL());

        // System.out.println("Vector 98, 345, 90 belongs to cluster: "
        //         + model.predict(Vectors.dense(98, 345, 90)));
        // System.out.println("Vector 748, 965, 202 belongs to cluster: "
        //         + model.predict(Vectors.dense(748, 965, 202)));
        // System.out.println("Vector 310, 554, 218 belongs to cluster: "
        //         + model.predict(Vectors.dense(310, 554, 218)));

        System.out.println("Cluster centers:");
        for (Vector center : model.clusterCenters()) {
            System.out.println("  " + center);
        }
        double cost = model.computeCost(points.rdd());
        System.out.println("Cost: " + cost);

        sc.stop();
    }
}
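As the ParsePoint function shows, each input line is expected to be a comma-separated 3-dimensional point, e.g. "98,345,90". Once a KMeansModel has been trained, it can be persisted and reused in a later job instead of retraining. The following is a minimal sketch of such a follow-up driver, assuming the main program above is extended with model.save(sc.sc(), modelPath) before sc.stop() and that the cluster runs Spark 1.4 or later (when KMeansModel.save/load were introduced); the class name JavaKMeansPredict and the HDFS path /user/xdf/kmeans-model are hypothetical.

package iie.udps.example.spark.mllib;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vectors;

/** Loads a previously saved KMeansModel and assigns a single point to a cluster. */
public final class JavaKMeansPredict {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaKMeansPredict");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Hypothetical path; the model must have been written earlier with
        // model.save(sc.sc(), modelPath) at the end of the training job.
        String modelPath = "/user/xdf/kmeans-model";
        KMeansModel model = KMeansModel.load(sc.sc(), modelPath);

        // Score one 3-dimensional point, matching the training data layout.
        int cluster = model.predict(Vectors.dense(98, 345, 90));
        System.out.println("Vector (98, 345, 90) belongs to cluster " + cluster);

        sc.stop();
    }
}

Keeping prediction in a separate driver like this avoids retraining on every run; only the cluster centers are read back from HDFS, which is cheap compared to rerunning KMeans.train over the full dataset.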