展开
拓展 关闭
订阅号推广码
GitHub
视频
公告栏 关闭

DecisionTree模型

  • 连接远程Anaconda3

  • 查看

  • 案例1

import findspark
findspark.init()
##############################################
from pyspark.sql import SparkSession
from pyspark.sql.context import SQLContext
from pyspark.ml.feature import StringIndexer, VectorAssembler
spark = SparkSession.builder \
        .master("local[*]") \
        .appName("PySpark ML") \
        .getOrCreate()
sc = spark.sparkContext

# assengerId(乘客ID)、Pclass(船舱等级)、Name(姓名)、Sex(性别)、
# Age(年龄)、SibSp(同行的兄弟姐妹/配偶数)、Parch(同行的父母/子女数)、Ticket(船票号码)、Fare(船票价格)、
# Cabin(客舱号码)、Embarked(登船港口 在哪个站上船的) 、Survived(是否幸存)

#############################################
df_train = spark.read.csv('./titanic-train.csv',header=True,inferSchema=True) \
  .cache()
df_train = df_train.fillna({'Age': round(29.699,0)})
df_train = df_train.fillna({'Embarked': 'S'})
labelIndexer = StringIndexer(inputCol="Embarked", outputCol="iEmbarked")
# 这行代码使用了 PySpark MLlib 中的 StringIndexer 转换器,用于将字符串类型的列转换为数值类型的列。具体来说,
# 它将 "Embarked" 列的字符串值转换成一个数值,然后将结果存储在新的列 "iEmbarked" 中。
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)
labelIndexer = StringIndexer(inputCol="Sex", outputCol="iSex")

df_train.show(20)

model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)
features = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked','Survived']
train_features = df_train[features]

# 因为大多数机器学习算法都希望输入是一个特征向量/
df_assembler = VectorAssembler(inputCols=['Pclass', 'iSex', 'Age', 'SibSp',
   'Parch', 'Fare', 'iEmbarked'], outputCol="features")
train = df_assembler.transform(train_features)

from pyspark.ml.classification import DecisionTreeClassifier
#DecisionTree模型
# labelCol="Survived": 这个参数指定了标签列,也就是模型要预测的目标列。在这里,标签列被命名为 "Survived",可能表示乘客是否幸存。
# featuresCol="features": 这个参数指定了特征列,也就是模型将使用的输入特征。在这里,特征列被命名为 "features",通常是一个包含多个特征的向量列,例如通过 VectorAssembler 合并得到的特征向量。
dtree = DecisionTreeClassifier(labelCol="Survived", featuresCol="features")
treeModel = dtree.fit(train)
#打印treeModel
# 详细了解决策树的构建和分割过程。
print(treeModel.toDebugString)
#对训练数据进行预测
dt_predictions = treeModel.transform(train)
# 预测值  真实值
dt_predictions.select("prediction", "Survived", "features").show(truncate=True)

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
multi_evaluator = MulticlassClassificationEvaluator(labelCol = 'Survived', metricName = 'accuracy')
# roc 曲线面积值  0~1 之间 越接近1 模型效果越好
print('Decision Tree Accu:', multi_evaluator.evaluate(dt_predictions))
#############################################

# 创建一个包含待预测特征的 DataFrame
new_data = [(3, 1.0, 30.0, 0, 0, 10.0, 2.0)]  # 举例的新数据,具体特征值需要根据你的实际情况填写
new_df = spark.createDataFrame(new_data, ["Pclass", "iSex", "Age", "SibSp", "Parch", "Fare", "iEmbarked"])

# 使用 VectorAssembler 创建特征向量列
new_features = df_assembler.transform(new_df)

# 使用训练好的决策树模型进行预测
predictions = treeModel.transform(new_features)

# 查看预测结果
print("查看预测结果"  )
predictions.select("prediction", "features").show()

sc.stop()
  • 控制台打印
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|iEmbarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|      0.0|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|      1.0|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|      0.0|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|      0.0|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|      0.0|
|          6|       0|     3|    Moran, Mr. James|  male|30.0|    0|    0|          330877| 8.4583| null|       Q|      2.0|
|          7|       0|     1|McCarthy, Mr. Tim...|  male|54.0|    0|    0|           17463|51.8625|  E46|       S|      0.0|
|          8|       0|     3|Palsson, Master. ...|  male| 2.0|    3|    1|          349909| 21.075| null|       S|      0.0|
|          9|       1|     3|Johnson, Mrs. Osc...|female|27.0|    0|    2|          347742|11.1333| null|       S|      0.0|
|         10|       1|     2|Nasser, Mrs. Nich...|female|14.0|    1|    0|          237736|30.0708| null|       C|      1.0|
|         11|       1|     3|Sandstrom, Miss. ...|female| 4.0|    1|    1|         PP 9549|   16.7|   G6|       S|      0.0|
|         12|       1|     1|Bonnell, Miss. El...|female|58.0|    0|    0|          113783|  26.55| C103|       S|      0.0|
|         13|       0|     3|Saundercock, Mr. ...|  male|20.0|    0|    0|       A/5. 2151|   8.05| null|       S|      0.0|
|         14|       0|     3|Andersson, Mr. An...|  male|39.0|    1|    5|          347082| 31.275| null|       S|      0.0|
|         15|       0|     3|Vestrom, Miss. Hu...|female|14.0|    0|    0|          350406| 7.8542| null|       S|      0.0|
|         16|       1|     2|Hewlett, Mrs. (Ma...|female|55.0|    0|    0|          248706|   16.0| null|       S|      0.0|
|         17|       0|     3|Rice, Master. Eugene|  male| 2.0|    4|    1|          382652| 29.125| null|       Q|      2.0|
|         18|       1|     2|Williams, Mr. Cha...|  male|30.0|    0|    0|          244373|   13.0| null|       S|      0.0|
|         19|       0|     3|Vander Planke, Mr...|female|31.0|    1|    0|          345763|   18.0| null|       S|      0.0|
|         20|       1|     3|Masselmani, Mrs. ...|female|30.0|    0|    0|            2649|  7.225| null|       C|      1.0|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+
only showing top 20 rows

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_359a22d7dd3d) of depth 5 with 35 nodes
  If (feature 1 in {0.0})
   If (feature 2 <= 3.5)
    If (feature 3 <= 2.5)
     Predict: 1.0
    Else (feature 3 > 2.5)
     If (feature 4 <= 1.5)
      Predict: 0.0
     Else (feature 4 > 1.5)
      If (feature 3 <= 4.5)
       Predict: 1.0
      Else (feature 3 > 4.5)
       Predict: 0.0
   Else (feature 2 > 3.5)
    If (feature 0 <= 1.5)
     If (feature 5 <= 26.125)
      Predict: 0.0
     Else (feature 5 > 26.125)
      If (feature 5 <= 26.46875)
       Predict: 1.0
      Else (feature 5 > 26.46875)
       Predict: 0.0
    Else (feature 0 > 1.5)
     If (feature 2 <= 15.5)
      If (feature 3 <= 1.5)
       Predict: 1.0
      Else (feature 3 > 1.5)
       Predict: 0.0
     Else (feature 2 > 15.5)
      Predict: 0.0
  Else (feature 1 not in {0.0})
   If (feature 0 <= 2.5)
    If (feature 2 <= 3.5)
     If (feature 0 <= 1.5)
      Predict: 0.0
     Else (feature 0 > 1.5)
      Predict: 1.0
    Else (feature 2 > 3.5)
     Predict: 1.0
   Else (feature 0 > 2.5)
    If (feature 5 <= 24.808349999999997)
     If (feature 6 in {1.0,2.0})
      If (feature 2 <= 30.25)
       Predict: 1.0
      Else (feature 2 > 30.25)
       Predict: 0.0
     Else (feature 6 not in {1.0,2.0})
      If (feature 5 <= 21.0375)
       Predict: 1.0
      Else (feature 5 > 21.0375)
       Predict: 0.0
    Else (feature 5 > 24.808349999999997)
     Predict: 0.0

+----------+--------+--------------------+
|prediction|Survived|            features|
+----------+--------+--------------------+
|       0.0|       0|[3.0,0.0,22.0,1.0...|
|       1.0|       1|[1.0,1.0,38.0,1.0...|
|       1.0|       1|[3.0,1.0,26.0,0.0...|
|       1.0|       1|[1.0,1.0,35.0,1.0...|
|       0.0|       0|(7,[0,2,5],[3.0,3...|
|       0.0|       0|[3.0,0.0,30.0,0.0...|
|       0.0|       0|(7,[0,2,5],[1.0,5...|
|       0.0|       0|[3.0,0.0,2.0,3.0,...|
|       1.0|       1|[3.0,1.0,27.0,0.0...|
|       1.0|       1|[2.0,1.0,14.0,1.0...|
|       1.0|       1|[3.0,1.0,4.0,1.0,...|
|       1.0|       1|[1.0,1.0,58.0,0.0...|
|       0.0|       0|(7,[0,2,5],[3.0,2...|
|       0.0|       0|[3.0,0.0,39.0,1.0...|
|       1.0|       0|[3.0,1.0,14.0,0.0...|
|       1.0|       1|[2.0,1.0,55.0,0.0...|
|       0.0|       0|[3.0,0.0,2.0,4.0,...|
|       0.0|       1|(7,[0,2,5],[2.0,3...|
|       1.0|       0|[3.0,1.0,31.0,1.0...|
|       1.0|       1|[3.0,1.0,30.0,0.0...|
+----------+--------+--------------------+
only showing top 20 rows

Decision Tree Accu: 0.8417508417508418
查看预测结果
+----------+--------------------+
|prediction|            features|
+----------+--------------------+
|       1.0|[3.0,1.0,30.0,0.0...|
+----------+--------------------+
posted @ 2024-01-09 21:27  DogLeftover  阅读(2)  评论(0编辑  收藏  举报