每日一题 为了工作 2020-05-07 第六十五题

package data.bysj.tree;

import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.util.HashMap;
import java.util.Map;

/**
 * Random-forest classification demo on Spark MLlib.
 *
 * <p>Reads tab-separated records from {@code ./save/rootData}, splits them 70/30 into
 * training and test sets, trains a binary random-forest classifier, and prints the test
 * accuracy. When the accuracy reaches {@link #ACCURACY_STANDARD_PERCENT}, it also prints the
 * model's tree structure.
 *
 * @author 雪瞳
 * @Slogan Even the clock keeps moving forward; how could a person stop here!
 * @Function Parse data, train a random-forest classifier, evaluate it on a held-out split.
 */
public class RandomForestTrees {

    /** Minimum test accuracy, in percent, for the model to be considered acceptable. */
    private static final double ACCURACY_STANDARD_PERCENT = 80.00;

    public static void main(String[] args) {
        String name = "forest";
        String master = "local[3]";
        SparkSession session = SparkSession.builder().master(master).appName(name).getOrCreate();
        try {
            JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
            jsc.setLogLevel("ERROR");

            JavaRDD<LabeledPoint> metaData = jsc.textFile("./save/rootData")
                    // A single blank line would otherwise fail parsing and abort the whole job.
                    .filter(line -> !line.trim().isEmpty())
                    .map(RandomForestTrees::parseLine);

            // 70% training / 30% test; the fixed seed keeps the split reproducible across runs.
            JavaRDD<LabeledPoint>[] splits = metaData.randomSplit(new double[]{0.7, 0.3}, 10L);
            JavaRDD<LabeledPoint> trainingData = splits[0];
            JavaRDD<LabeledPoint> testData = splits[1];

            int numClass = 2;
            // Empty map: every feature is treated as continuous.
            Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
            int numTrees = 3;
            String featureSubsetStrategy = "auto";
            String impurity = "entropy";
            int maxDepth = 4;
            int maxBins = 32;
            int seed = 1;
            RandomForestModel model = RandomForest.trainClassifier(
                    trainingData,
                    numClass,
                    categoricalFeaturesInfo,
                    numTrees,
                    featureSubsetStrategy,
                    impurity,
                    maxDepth,
                    maxBins,
                    seed);

            // Pair each prediction with its true label in one pass instead of zipping two
            // separately recomputed RDDs; cached because it is counted twice below.
            JavaPairRDD<Double, Double> resultRDD = testData
                    .mapToPair(point -> new Tuple2<>(model.predict(point.features()), point.label()))
                    .cache();
            long count = resultRDD.count();
            long correct = resultRDD
                    .filter(tp -> Double.compare(tp._1(), tp._2()) == 0)
                    .count();
            resultRDD.unpersist();

            System.err.println("数目是:" + count);
            System.err.println("正确数目是:" + correct);
            if (count == 0) {
                // Avoid reporting NaN accuracy for an empty test split.
                System.err.println("测试集为空,无法计算正确率");
                return;
            }

            double rate = correct / (double) count;
            double ratePercent = rate * 100;
            System.err.println("正确率是:" + ratePercent + "%");

            // Compare in the same unit (percent) as the standard; the model is only reported
            // (and would be persisted) when it meets the standard, not when it falls short.
            if (Double.compare(ratePercent, ACCURACY_STANDARD_PERCENT) >= 0) {
                // Persisting is intentionally disabled; to enable:
                // model.save(jsc.sc(), "./save/model");
                System.out.println(model.toDebugString());
            }
        } finally {
            // Release the Spark context even when reading, training, or evaluation fails.
            session.stop();
        }
    }

    /**
     * Parses one tab-separated record into a {@link LabeledPoint}.
     *
     * <p>Layout, per the sample record
     * {@code "2015-11-01 20:20:16"\t1.8599...\t...\t0.0149...\t\t0}: field 0 is a timestamp,
     * the following fields are numeric features, the penultimate field is empty, and the last
     * field is the class label. Only the feature fields go into the vector.
     *
     * @param line one raw, non-blank input line
     * @return the parsed label and dense feature vector
     * @throws IllegalArgumentException if the line has too few fields to hold a feature
     * @throws NumberFormatException if a feature or the label is not a number
     */
    private static LabeledPoint parseLine(String line) {
        String[] fields = line.split("\t");
        if (fields.length < 4) {
            throw new IllegalArgumentException("Malformed record: " + line);
        }
        double label = Double.parseDouble(fields[fields.length - 1]);
        // Skip field 0 (timestamp) and the last two fields (empty field + label).
        double[] features = new double[fields.length - 3];
        for (int i = 0; i < features.length; i++) {
            features[i] = Double.parseDouble(fields[i + 1]);
        }
        return new LabeledPoint(label, Vectors.dense(features));
    }
}

  

 
posted @ 2020-05-07 16:34  雪瞳  阅读(124)  评论(0)  编辑  收藏  举报