逻辑斯蒂回归分类器(Logistic Regression)（注：下文示例代码实际训练的是决策树分类器 DecisionTreeClassifier；LogisticRegression 虽被导入，但并未使用）

from pyspark.ml.linalg import Vector,Vectors
from pyspark.sql import Row,functions
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.feature import IndexToString, StringIndexer,VectorIndexer,HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression,LogisticRegressionModel,BinaryLogisticRegressionSummary,LogisticRegression
from pyspark.ml.classification import DecisionTreeClassificationModel
from pyspark.ml.classification import DecisionTreeClassifier

# The interactive pyspark shell predefines `spark` and `sc`, but a script run
# through spark-submit does not. Obtain them explicitly; getOrCreate() returns
# the already-running session when one exists, so this is harmless in the shell.
# RDD.toDF() below also relies on an active SparkSession.
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

# Read the raw text file and split each line on ';' into a list of string
# fields. toDF() names the resulting columns _1, _2, ...
# NOTE(review): if the file has a header line, it becomes a data row here - confirm.
data = (
    sc.textFile("file:///home/hw17685187119/student2.txt")
    # A blank line would split into [''], a row whose length differs from the
    # others, which makes toDF() fail.
    .filter(lambda line: line.strip())
    .map(lambda line: line.split(';'))
    .toDF()
)

def f(x, label_index=22, feature_indices=(2, 24, 28, 29)):
    """Turn one positional record into the ``label``/``features`` dict for a Row.

    Args:
        x: Sequence of fields for one record, addressed by position.
        label_index: Position of the field used as the class label.
            Defaults to 22, the column the script has always used.
        feature_indices: Positions of the numeric fields packed into the
            feature vector. Defaults to (2, 24, 28, 29), as before.

    Returns:
        dict with ``'label'`` (the raw label field as a string) and
        ``'features'`` (a dense vector of floats).

    Raises:
        IndexError: if ``x`` is shorter than a requested position.
        ValueError: if a selected feature field is not numeric.
    """
    rel = {}
    # The label stays a string; StringIndexer turns it into a numeric index later.
    rel['label'] = str(x[label_index])
    # Convert explicitly to float rather than handing strings to Vectors.dense.
    # Surrounding double quotes are stripped so a quoted numeric field such as
    # '"5"' parses as 5.0 instead of failing.
    rel['features'] = Vectors.dense(
        [float(str(x[i]).strip().strip('"')) for i in feature_indices]
    )
    return rel

# Rebuild the DataFrame as (label, features) rows via f().
# Each Row's values are handed to f() directly as a list of strings, in column
# order. Joining all columns with "," and re-splitting would shift every
# following field position whenever a value contains a comma. It would also
# hard-code the column names _1.._33.
rdd = data.rdd.map(lambda p: [str(v) for v in p])
data = rdd.map(lambda fields: Row(**f(fields))).toDF()

# Learn the label-string -> numeric-index mapping on the whole dataset, then
# preview the first 20 indexed rows.
labelIndexer = StringIndexer(
    inputCol="label",
    outputCol="indexedLabel",
).fit(data)
indexed = labelIndexer.transform(data)
indexed.show(n=20)

# Pipeline stages: the already-fitted label indexer, then a decision tree
# trained on indexedLabel/features, then a converter that maps numeric
# predictions back to the original label strings.
dtClassifier = DecisionTreeClassifier(
    labelCol="indexedLabel",
    featuresCol="features",
)
labelConverter = IndexToString(
    inputCol="prediction",
    outputCol="predictedLabel",
    labels=labelIndexer.labels,
)
dtPipeline = Pipeline(stages=[labelIndexer, dtClassifier, labelConverter])
# Hold out roughly 30% of rows for evaluation. A fixed seed makes the split,
# and therefore the reported metric, reproducible between runs. Change or
# remove the seed to draw a different random split.
trainingData, testData = data.randomSplit([0.7, 0.3], seed=42)
dtPipelineModel = dtPipeline.fit(trainingData)
dtPredictions = dtPipelineModel.transform(testData)
# Show predicted vs. true labels alongside the inputs for the first 20 test rows.
dtPredictions.select("predictedLabel", "label", "features").show(20)

# MulticlassClassificationEvaluator defaults to metricName="f1". Request
# "accuracy" explicitly so the value stored in dtAccuracy really is accuracy.
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel",
    predictionCol="prediction",
    metricName="accuracy",
)
dtAccuracy = evaluator.evaluate(dtPredictions)
# A bare expression is only echoed in an interactive shell; print it so the
# result is also visible when the script runs via spark-submit.
print("Decision tree test accuracy:", dtAccuracy)

posted @ 2021-06-16 13:45  Plum_Brilliant  阅读(119)  评论(0)  编辑  收藏  举报