每日一题 为了工作 2020-05-04 第六十二题

package data.bjsj.fjjb;
 
 
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
 
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;
 
 
/**
 *
 * @author 雪瞳
 * @Slogan 时钟尚且前行,人怎能就此止步!
 * @Function
 *
 */
public class LogisticModel {
    public static void main(String[] args) {
         
        SparkSession session = SparkSession.builder().appName("logistic").master("local").getOrCreate();
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
        SparkContext sc = JavaSparkContext.toSparkContext(jsc);
 
        jsc.setLogLevel("Error");
        JavaRDD<String> fileRDD = jsc.textFile("./save/rootData");
        JavaRDD<LabeledPoint> labeledPointJavaRDD = fileRDD.map(new Function<String, LabeledPoint>() {
            //"2015-11-01 20:20:16" 1.85999330468501    1.22359452534749    2.51578969727773    -0.403918740333512  0.0149184125297424      0
            @Override
            public LabeledPoint call(String line) throws Exception {
                String[] splits = line.split("\t");
                String label = splits[splits.length - 1];
 
                double[] wd = new double[splits.length - 3];
                for (int i = 0; i < wd.length; i++) {
                    wd[i] = Double.parseDouble(splits[i+1]);
                }
                LabeledPoint labeledPoint = new LabeledPoint(Double.parseDouble(label), Vectors.dense(wd));
                return labeledPoint;
            }
        });
 
         
        double[] doubles = new double[]{0.7,0.3};
        RDD<LabeledPoint> rdd = labeledPointJavaRDD.rdd();
        RDD<LabeledPoint>[] metaDataSource = rdd.randomSplit(doubles, 100L);
         
        RDD<LabeledPoint> traingData = metaDataSource[0];
        RDD<LabeledPoint> testData = metaDataSource[1];
         
        LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS();
        lr.setNumClasses(2);
        lr.setIntercept(true);
        LogisticRegressionModel model = lr.run(traingData);
        JavaRDD<Double> predictRdd = testData.toJavaRDD().map(new Function<LabeledPoint, Double>() {
            @Override
            public Double call(LabeledPoint labeledPoint) throws Exception {
                double predict = model.predict(labeledPoint.features());
                return predict;
            }
        });
        JavaPairRDD<Double, Double> zipRdd = predictRdd.zip(testData.toJavaRDD().map(new Function<LabeledPoint, Double>() {
            @Override
            public Double call(LabeledPoint labeledPoint) throws Exception {
                return labeledPoint.label();
            }
        }));
 
        Accumulator<Integer> accumulator = jsc.accumulator(0);
        zipRdd.foreach(new VoidFunction<Tuple2<Double, Double>>() {
            @Override
            public void call(Tuple2<Double, Double> tp) throws Exception {
                Double label = tp._2();
                Double predict = tp._1();
                if (Double.compare(label,predict)==0){
                    accumulator.add(1);
                }
            }
        });
        long count = zipRdd.count();
        Integer value = accumulator.value();
        System.err.println("总数目是:"+count);
        System.err.println("正确数目是:"+value);
        double rate = value / (double) count;
        System.err.println("正确率是:"+rate*100+"%");
        String path ="./save/model";
        double  stand = 80.00;
        if (Double.compare(rate,stand)<0){
            model.save(sc,path);
        }
    }
}

  

posted @ 雪瞳  阅读(119)  评论(0)  编辑  收藏  举报
编辑推荐:
· 深入理解 Mybatis 分库分表执行原理
· 如何打造一个高并发系统?
· .NET Core GC压缩(compact_phase)底层原理浅谈
· 现代计算机视觉入门之:什么是图片特征编码
· .NET 9 new features-C#13新的锁类型和语义
阅读排行:
· Sdcb Chats 技术博客:数据库 ID 选型的曲折之路 - 从 Guid 到自增 ID,再到
· 语音处理 开源项目 EchoSharp
· 《HelloGitHub》第 106 期
· Spring AI + Ollama 实现 deepseek-r1 的API服务和调用
· 使用 Dify + LLM 构建精确任务处理应用
点击右上角即可分享
微信分享提示