// 每日一题 为了工作 2020 0507 第六十五题 (Daily exercise "for work", 2020-05-07, problem #65)
package data.bysj.tree;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;
import java.util.HashMap;
import java.util.Map;
/**
 * Trains and evaluates a Spark MLlib random-forest binary classifier on a local Spark session.
 *
 * <p>Input is read from {@code ./save/rootData}, one tab-separated record per line. The first
 * column is a timestamp, the last column is the class label, and every column in between is a
 * numeric feature. The data is split 70/30 into training and test sets, and a 3-tree forest is
 * trained. The test-set size, correct-prediction count and accuracy are printed to stderr.
 *
 * @author 雪瞳
 * @Slogan The clock keeps moving forward; how could a person stop here!
 */
public class RandomForestTrees {

    /**
     * Entry point: loads the data, trains the forest and prints its accuracy on the test split.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String name = "forest";
        String master = "local[3]";
        SparkSession session = SparkSession.builder().master(master).appName(name).getOrCreate();
        try {
            JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
            jsc.setLogLevel("Error");

            JavaRDD<LabeledPoint> metaData = jsc.textFile("./save/rootData")
                    // Skip blank lines (e.g. a trailing newline) instead of failing on them.
                    .filter(line -> !line.trim().isEmpty())
                    .map(RandomForestTrees::parseLine)
                    // Both halves of randomSplit read this RDD; avoid re-reading and re-parsing.
                    .cache();

            // 70% training / 30% test; the fixed seed makes the split reproducible.
            JavaRDD<LabeledPoint>[] splits = metaData.randomSplit(new double[]{0.7, 0.3}, 10L);
            JavaRDD<LabeledPoint> trainingData = splits[0];
            JavaRDD<LabeledPoint> testData = splits[1];

            RandomForestModel model = train(trainingData);
            evaluate(model, testData);
        } finally {
            // Release the Spark context even if loading, training or evaluation fails.
            session.stop();
        }
    }

    /**
     * Parses one tab-separated input record into a {@link LabeledPoint}.
     *
     * <p>Example record (columns separated by tabs):
     * {@code "2015-11-01 20:20:16" 1.85999330468501 1.22359452534749 2.51578969727773
     * -0.403918740333512 0.0149184125297424 0}
     *
     * @param line one record: timestamp, one or more numeric features, numeric label
     * @return a point whose label is the last column and whose features are every column
     *     between the timestamp and the label
     * @throws IllegalArgumentException if the record has fewer than three columns
     * @throws NumberFormatException if a feature or label column is not a number
     */
    private static LabeledPoint parseLine(String line) {
        String[] columns = line.split("\t");
        if (columns.length < 3) {
            throw new IllegalArgumentException("Expected timestamp, features and label: " + line);
        }
        double label = Double.parseDouble(columns[columns.length - 1]);
        // Column 0 is the timestamp and the last column is the label, so the remaining
        // length - 2 columns are all features.
        double[] features = new double[columns.length - 2];
        for (int i = 0; i < features.length; i++) {
            features[i] = Double.parseDouble(columns[i + 1]);
        }
        return new LabeledPoint(label, Vectors.dense(features));
    }

    /**
     * Trains the random-forest classifier with this job's fixed hyper-parameters.
     *
     * @param trainingData labelled training points; with numClasses = 2, MLlib expects the
     *     labels to be 0 or 1
     * @return the trained forest
     */
    private static RandomForestModel train(JavaRDD<LabeledPoint> trainingData) {
        int numClasses = 2;
        // Empty map: no feature is declared categorical, so all are treated as continuous.
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        int numTrees = 3;
        String featureSubsetStrategy = "auto"; // let MLlib choose the per-node feature subset
        String impurity = "entropy";
        int maxDepth = 4;
        int maxBins = 32;
        int seed = 1;
        return RandomForest.trainClassifier(
                trainingData,
                numClasses,
                categoricalFeaturesInfo,
                numTrees,
                featureSubsetStrategy,
                impurity,
                maxDepth,
                maxBins,
                seed);
    }

    /**
     * Prints the test-set size, the number of correct predictions and the accuracy to stderr.
     * If the accuracy is below 80%, it also prints the forest's structure to stdout.
     *
     * @param model the trained forest
     * @param testData held-out labelled points
     */
    private static void evaluate(RandomForestModel model, JavaRDD<LabeledPoint> testData) {
        long count = testData.count();
        // Count correct predictions with filter+count. This avoids the deprecated Accumulator
        // API and avoids zip() of two separately computed RDDs, which needs identical partitioning.
        long correct = testData
                .filter(point -> Double.compare(point.label(), model.predict(point.features())) == 0)
                .count();
        System.err.println("数目是:" + count);
        System.err.println("正确数目是:" + correct);
        if (count == 0) {
            // Avoid reporting a NaN (0 / 0) accuracy when the split leaves no test data.
            System.err.println("测试集为空, 无法计算正确率");
            return;
        }
        double rate = correct / (double) count;
        System.err.println("正确率是:" + rate * 100 + "%");
        // Threshold is a fraction (80%) to match rate, which lies in [0, 1].
        double accuracyThreshold = 0.80;
        if (Double.compare(rate, accuracyThreshold) < 0) {
            // TODO(review): persisting the model (model.save(sparkContext, "./save/model")) was
            // sketched for this below-threshold branch; saving is more plausibly wanted when the
            // accuracy meets the threshold. Confirm the intent before enabling it.
            System.out.println(model.toDebugString());
        }
    }
}