DecisionTree model
- Connect to the remote Anaconda3
- View
- Example 1
import findspark
findspark.init()
##############################################
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler
spark = SparkSession.builder \
    .master("local[*]") \
    .appName("PySpark ML") \
    .getOrCreate()
sc = spark.sparkContext
# PassengerId (passenger ID), Pclass (ticket class), Name, Sex,
# Age, SibSp (siblings/spouses aboard), Parch (parents/children aboard), Ticket (ticket number), Fare (ticket fare),
# Cabin (cabin number), Embarked (port of embarkation), Survived (survival flag)
#############################################
df_train = spark.read.csv('./titanic-train.csv', header=True, inferSchema=True) \
    .cache()
df_train = df_train.fillna({'Age': round(29.699,0)})
df_train = df_train.fillna({'Embarked': 'S'})
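# A hedged note: 29.699 is presumably the mean of the non-null Age values, which could be
# checked with the line below (kept commented out so the console capture further down is unchanged).
# Missing Embarked values are filled with 'S', the most frequent port in this dataset.
# df_train.agg({'Age': 'mean'}).show()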
labelIndexer = StringIndexer(inputCol="Embarked", outputCol="iEmbarked")
# StringIndexer (PySpark ML) converts a string column into a numeric index column: each
# distinct value of "Embarked" is mapped to a number, and the result is stored in the new column "iEmbarked".
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)
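# A hedged inspection step (commented out, not part of the original run): the fitted
# StringIndexerModel exposes the learned label order via .labels; the most frequent value
# gets index 0.0, which matches the output below (S -> 0.0, C -> 1.0, Q -> 2.0).
# print(model.labels)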
labelIndexer = StringIndexer(inputCol="Sex", outputCol="iSex")
df_train.show(20)
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)
features = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked','Survived']
train_features = df_train[features]
# Most machine learning algorithms expect the input to be a single feature-vector column.
df_assembler = VectorAssembler(inputCols=['Pclass', 'iSex', 'Age', 'SibSp',
'Parch', 'Fare', 'iEmbarked'], outputCol="features")
train = df_assembler.transform(train_features)
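# A hedged inspection step (commented out): each row now carries a 7-element "features"
# vector in the VectorAssembler input order; rows with many zeros may print in sparse
# (size, [indices], [values]) form, as seen in the output below.
# train.select("features", "Survived").show(5, truncate=False)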
from pyspark.ml.classification import DecisionTreeClassifier
# DecisionTree model
# labelCol="Survived": the label column, i.e. the target the model predicts (here, whether the passenger survived).
# featuresCol="features": the input feature column, a vector column such as the one built by VectorAssembler above.
dtree = DecisionTreeClassifier(labelCol="Survived", featuresCol="features")
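# A hedged sketch (not used in the original run): commonly tuned DecisionTreeClassifier
# hyperparameters, shown here with PySpark's defaults.
# dtree = DecisionTreeClassifier(labelCol="Survived", featuresCol="features",
#                                maxDepth=5, maxBins=32, impurity="gini")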
treeModel = dtree.fit(train)
# Print the model structure to see how the tree was built and where it splits.
print(treeModel.toDebugString)
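# Feature indices in the tree printout map to the VectorAssembler input order:
# 0=Pclass, 1=iSex, 2=Age, 3=SibSp, 4=Parch, 5=Fare, 6=iEmbarked.
# A hedged inspection step (commented out): per-feature importance scores, in the same index order.
# print(treeModel.featureImportances)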
# Predict on the training data
dt_predictions = treeModel.transform(train)
# prediction vs. actual label
dt_predictions.select("prediction", "Survived", "features").show(truncate=True)
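# A hedged inspection step (commented out): a quick confusion-matrix-style breakdown of
# predictions vs. labels on the training data.
# dt_predictions.groupBy("Survived", "prediction").count().show()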
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
multi_evaluator = MulticlassClassificationEvaluator(labelCol = 'Survived', metricName = 'accuracy')
# Accuracy is between 0 and 1; the closer to 1, the better the model fits the training data
# (note this is training accuracy, not a held-out score).
print('Decision Tree Accu:', multi_evaluator.evaluate(dt_predictions))
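# A hedged sketch (commented out, not part of the original run): if an ROC-AUC figure is
# wanted, the binary evaluator can be used; "rawPrediction" is the column the classifier
# emits by default. Both scores here are computed on the training set, so they overstate
# generalisation; a held-out split via df_train.randomSplit([0.8, 0.2]) would be more informative.
# from pyspark.ml.evaluation import BinaryClassificationEvaluator
# bin_evaluator = BinaryClassificationEvaluator(labelCol='Survived',
#                                               rawPredictionCol='rawPrediction',
#                                               metricName='areaUnderROC')
# print('Decision Tree AUC:', bin_evaluator.evaluate(dt_predictions))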
#############################################
# Build a DataFrame containing the features of new passengers to score
new_data = [(3, 1.0, 30.0, 0, 0, 10.0, 2.0)]  # example values; replace with real feature values as needed
new_df = spark.createDataFrame(new_data, ["Pclass", "iSex", "Age", "SibSp", "Parch", "Fare", "iEmbarked"])
# Reuse the same VectorAssembler to build the feature-vector column
new_features = df_assembler.transform(new_df)
# Score the new rows with the trained decision tree model
predictions = treeModel.transform(new_features)
# Show the prediction
print("查看预测结果" )
predictions.select("prediction", "features").show()
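# A hedged addition (commented out): the classifier also emits a "probability" column with
# per-class probabilities, useful when a confidence estimate is needed.
# predictions.select("prediction", "probability").show(truncate=False)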
sc.stop()
- Console output
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+
|PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|iEmbarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+
| 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S| 0.0|
| 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C| 1.0|
| 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S| 0.0|
| 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S| 0.0|
| 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S| 0.0|
| 6| 0| 3| Moran, Mr. James| male|30.0| 0| 0| 330877| 8.4583| null| Q| 2.0|
| 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S| 0.0|
| 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S| 0.0|
| 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S| 0.0|
| 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C| 1.0|
| 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S| 0.0|
| 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S| 0.0|
| 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S| 0.0|
| 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S| 0.0|
| 15| 0| 3|Vestrom, Miss. Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S| 0.0|
| 16| 1| 2|Hewlett, Mrs. (Ma...|female|55.0| 0| 0| 248706| 16.0| null| S| 0.0|
| 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q| 2.0|
| 18| 1| 2|Williams, Mr. Cha...| male|30.0| 0| 0| 244373| 13.0| null| S| 0.0|
| 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S| 0.0|
| 20| 1| 3|Masselmani, Mrs. ...|female|30.0| 0| 0| 2649| 7.225| null| C| 1.0|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+
only showing top 20 rows
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_359a22d7dd3d) of depth 5 with 35 nodes
  If (feature 1 in {0.0})
   If (feature 2 <= 3.5)
    If (feature 3 <= 2.5)
     Predict: 1.0
    Else (feature 3 > 2.5)
     If (feature 4 <= 1.5)
      Predict: 0.0
     Else (feature 4 > 1.5)
      If (feature 3 <= 4.5)
       Predict: 1.0
      Else (feature 3 > 4.5)
       Predict: 0.0
   Else (feature 2 > 3.5)
    If (feature 0 <= 1.5)
     If (feature 5 <= 26.125)
      Predict: 0.0
     Else (feature 5 > 26.125)
      If (feature 5 <= 26.46875)
       Predict: 1.0
      Else (feature 5 > 26.46875)
       Predict: 0.0
    Else (feature 0 > 1.5)
     If (feature 2 <= 15.5)
      If (feature 3 <= 1.5)
       Predict: 1.0
      Else (feature 3 > 1.5)
       Predict: 0.0
     Else (feature 2 > 15.5)
      Predict: 0.0
  Else (feature 1 not in {0.0})
   If (feature 0 <= 2.5)
    If (feature 2 <= 3.5)
     If (feature 0 <= 1.5)
      Predict: 0.0
     Else (feature 0 > 1.5)
      Predict: 1.0
    Else (feature 2 > 3.5)
     Predict: 1.0
   Else (feature 0 > 2.5)
    If (feature 5 <= 24.808349999999997)
     If (feature 6 in {1.0,2.0})
      If (feature 2 <= 30.25)
       Predict: 1.0
      Else (feature 2 > 30.25)
       Predict: 0.0
     Else (feature 6 not in {1.0,2.0})
      If (feature 5 <= 21.0375)
       Predict: 1.0
      Else (feature 5 > 21.0375)
       Predict: 0.0
    Else (feature 5 > 24.808349999999997)
     Predict: 0.0
+----------+--------+--------------------+
|prediction|Survived| features|
+----------+--------+--------------------+
| 0.0| 0|[3.0,0.0,22.0,1.0...|
| 1.0| 1|[1.0,1.0,38.0,1.0...|
| 1.0| 1|[3.0,1.0,26.0,0.0...|
| 1.0| 1|[1.0,1.0,35.0,1.0...|
| 0.0| 0|(7,[0,2,5],[3.0,3...|
| 0.0| 0|[3.0,0.0,30.0,0.0...|
| 0.0| 0|(7,[0,2,5],[1.0,5...|
| 0.0| 0|[3.0,0.0,2.0,3.0,...|
| 1.0| 1|[3.0,1.0,27.0,0.0...|
| 1.0| 1|[2.0,1.0,14.0,1.0...|
| 1.0| 1|[3.0,1.0,4.0,1.0,...|
| 1.0| 1|[1.0,1.0,58.0,0.0...|
| 0.0| 0|(7,[0,2,5],[3.0,2...|
| 0.0| 0|[3.0,0.0,39.0,1.0...|
| 1.0| 0|[3.0,1.0,14.0,0.0...|
| 1.0| 1|[2.0,1.0,55.0,0.0...|
| 0.0| 0|[3.0,0.0,2.0,4.0,...|
| 0.0| 1|(7,[0,2,5],[2.0,3...|
| 1.0| 0|[3.0,1.0,31.0,1.0...|
| 1.0| 1|[3.0,1.0,30.0,0.0...|
+----------+--------+--------------------+
only showing top 20 rows
Decision Tree Accu: 0.8417508417508418
Prediction results:
+----------+--------------------+
|prediction| features|
+----------+--------------------+
| 1.0|[3.0,1.0,30.0,0.0...|
+----------+--------------------+