Spark MLlib collaborative filtering 学习 (study notes)

 1 package ML.collaborativeFilltering;
 2 
 3 import org.apache.spark.SparkConf;
 4 import org.apache.spark.api.java.JavaDoubleRDD;
 5 import org.apache.spark.api.java.JavaPairRDD;
 6 import org.apache.spark.api.java.JavaRDD;
 7 import org.apache.spark.api.java.JavaSparkContext;
 8 import org.apache.spark.api.java.function.Function;
 9 import org.apache.spark.mllib.recommendation.ALS;
10 import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
11 import org.apache.spark.mllib.recommendation.Rating;
12 import scala.Tuple2;
13 
14 /**
15  * TODO
16  *
17  * @ClassName: example
18  * @author: DingH
19  * @since: 2019/4/10 16:03
20  */
21 public class example {
22     public static void main(String[] args) {
23         SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example");
24         JavaSparkContext jsc = new JavaSparkContext(conf);
25 
26         // Load and parse the data
27         String path = "D:\\IdeaProjects\\SimpleApp\\src\\main\\resources\\data\\mllib\\als\\test.data";
28         JavaRDD<String> data = jsc.textFile(path);
29         JavaRDD<Rating> ratings = data.map(new Function<String, Rating>() {
30             public Rating call(String s) {
31                 String[] sarray = s.split(",");
32                 return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2]));
33             }
34           }
35         );
36         int ranks = 10;
37         int numIterations = 10;
38         MatrixFactorizationModel model = ALS.train(ratings.rdd(), ranks, numIterations);
39 
40         JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(new Function<Rating, Tuple2<Object, Object>>() {
41             public Tuple2<Object, Object> call(Rating r) {
42               return new Tuple2<Object, Object>(r.user(), r.product());
43             }
44           }
45         );
46         JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
47             new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
48               public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r){
49                 return new Tuple2<Tuple2<Integer, Integer>, Double>(
50                   new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
51               }
52             }
53           ));
54 
55         JavaRDD<Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map(
56             new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
57               public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r){
58                 return new Tuple2<Tuple2<Integer, Integer>, Double>(
59                   new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
60               }
61             }
62           )).join(predictions).values();
63 
64         double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map(
65           new Function<Tuple2<Double, Double>, Object>() {
66             public Object call(Tuple2<Double, Double> pair) {
67               Double err = pair._1() - pair._2();
68               return err * err;
69             }
70           }
71         ).rdd()).mean();
72 
73         System.out.println("Mean Squared Error = " + MSE);
74 
75 
76 
77 
78     }
79 }

 

posted @ 2019-04-10 16:23  _Meditation  阅读(254)  评论(0)  编辑  收藏  举报