DecisionTree模型
-
连接远程Anaconda3
-
查看
-
案例1
# Titanic survival prediction with a PySpark DecisionTreeClassifier.
#
# Pipeline: read CSV -> impute missing Age/Embarked -> index categorical
# columns (Embarked, Sex) -> assemble a feature vector -> fit a decision
# tree -> inspect the tree -> score on the training data -> predict one
# new passenger.
import findspark

findspark.init()

from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler

spark = SparkSession.builder \
    .master("local[*]") \
    .appName("PySpark ML") \
    .getOrCreate()
sc = spark.sparkContext

# Dataset columns: PassengerId, Pclass (cabin class), Name, Sex, Age,
# SibSp (siblings/spouses aboard), Parch (parents/children aboard),
# Ticket, Fare, Cabin, Embarked (port of embarkation), Survived (label).
df_train = spark.read.csv('./titanic-train.csv', header=True, inferSchema=True) \
    .cache()

# Impute missing values: Age gets 30.0 (the training mean 29.699 rounded);
# Embarked gets 'S', the most common port.
df_train = df_train.fillna({'Age': round(29.699, 0)})
df_train = df_train.fillna({'Embarked': 'S'})

# StringIndexer maps each distinct string value of "Embarked" to a numeric
# index (most frequent value -> 0.0) in the new column "iEmbarked".
labelIndexer = StringIndexer(inputCol="Embarked", outputCol="iEmbarked")
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)

# Same encoding for "Sex" -> "iSex". The show(20) between indexer creation
# and fit reproduces the original script's console output ordering.
labelIndexer = StringIndexer(inputCol="Sex", outputCol="iSex")
df_train.show(20)
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)

# Keep only the model inputs plus the label.
features = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked', 'Survived']
train_features = df_train[features]

# Most MLlib estimators expect a single vector column of features;
# VectorAssembler concatenates the listed columns into "features".
df_assembler = VectorAssembler(
    inputCols=['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked'],
    outputCol="features")
train = df_assembler.transform(train_features)

from pyspark.ml.classification import DecisionTreeClassifier

# labelCol="Survived": the target column the model predicts.
# featuresCol="features": the assembled input vector column.
dtree = DecisionTreeClassifier(labelCol="Survived", featuresCol="features")
treeModel = dtree.fit(train)

# Dump the learned tree structure (splits and leaf predictions).
print(treeModel.toDebugString)

# Score the model; each row gains "prediction" (and probability columns).
dt_predictions = treeModel.transform(train)
dt_predictions.select("prediction", "Survived", "features").show(truncate=True)

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# metricName='accuracy' computes plain multiclass accuracy (fraction of
# correct predictions), NOT ROC-AUC as the original comment claimed.
# NOTE(review): this evaluates on the TRAINING data, so the 0.84 figure is
# an optimistic estimate — a train/test split would give an honest number.
multi_evaluator = MulticlassClassificationEvaluator(labelCol='Survived',
                                                    metricName='accuracy')
print('Decision Tree Accu:', multi_evaluator.evaluate(dt_predictions))

#############################################
# Predict a single new passenger. The tuple order must match the
# assembler's inputCols: Pclass, iSex, Age, SibSp, Parch, Fare, iEmbarked.
new_data = [(3, 1.0, 30.0, 0, 0, 10.0, 2.0)]  # example values
new_df = spark.createDataFrame(
    new_data, ["Pclass", "iSex", "Age", "SibSp", "Parch", "Fare", "iEmbarked"])

# Reuse the same assembler so the feature layout matches training.
new_features = df_assembler.transform(new_df)
predictions = treeModel.transform(new_features)

# Show the prediction for the new passenger.
print("查看预测结果")
predictions.select("prediction", "features").show()
sc.stop()
- 控制台打印
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+ |PassengerId|Survived|Pclass| Name| Sex| Age|SibSp|Parch| Ticket| Fare|Cabin|Embarked|iEmbarked| +-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+ | 1| 0| 3|Braund, Mr. Owen ...| male|22.0| 1| 0| A/5 21171| 7.25| null| S| 0.0| | 2| 1| 1|Cumings, Mrs. Joh...|female|38.0| 1| 0| PC 17599|71.2833| C85| C| 1.0| | 3| 1| 3|Heikkinen, Miss. ...|female|26.0| 0| 0|STON/O2. 3101282| 7.925| null| S| 0.0| | 4| 1| 1|Futrelle, Mrs. Ja...|female|35.0| 1| 0| 113803| 53.1| C123| S| 0.0| | 5| 0| 3|Allen, Mr. Willia...| male|35.0| 0| 0| 373450| 8.05| null| S| 0.0| | 6| 0| 3| Moran, Mr. James| male|30.0| 0| 0| 330877| 8.4583| null| Q| 2.0| | 7| 0| 1|McCarthy, Mr. Tim...| male|54.0| 0| 0| 17463|51.8625| E46| S| 0.0| | 8| 0| 3|Palsson, Master. ...| male| 2.0| 3| 1| 349909| 21.075| null| S| 0.0| | 9| 1| 3|Johnson, Mrs. Osc...|female|27.0| 0| 2| 347742|11.1333| null| S| 0.0| | 10| 1| 2|Nasser, Mrs. Nich...|female|14.0| 1| 0| 237736|30.0708| null| C| 1.0| | 11| 1| 3|Sandstrom, Miss. ...|female| 4.0| 1| 1| PP 9549| 16.7| G6| S| 0.0| | 12| 1| 1|Bonnell, Miss. El...|female|58.0| 0| 0| 113783| 26.55| C103| S| 0.0| | 13| 0| 3|Saundercock, Mr. ...| male|20.0| 0| 0| A/5. 2151| 8.05| null| S| 0.0| | 14| 0| 3|Andersson, Mr. An...| male|39.0| 1| 5| 347082| 31.275| null| S| 0.0| | 15| 0| 3|Vestrom, Miss. Hu...|female|14.0| 0| 0| 350406| 7.8542| null| S| 0.0| | 16| 1| 2|Hewlett, Mrs. (Ma...|female|55.0| 0| 0| 248706| 16.0| null| S| 0.0| | 17| 0| 3|Rice, Master. Eugene| male| 2.0| 4| 1| 382652| 29.125| null| Q| 2.0| | 18| 1| 2|Williams, Mr. Cha...| male|30.0| 0| 0| 244373| 13.0| null| S| 0.0| | 19| 0| 3|Vander Planke, Mr...|female|31.0| 1| 0| 345763| 18.0| null| S| 0.0| | 20| 1| 3|Masselmani, Mrs. 
...|female|30.0| 0| 0| 2649| 7.225| null| C| 1.0| +-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+---------+ only showing top 20 rows DecisionTreeClassificationModel (uid=DecisionTreeClassifier_359a22d7dd3d) of depth 5 with 35 nodes If (feature 1 in {0.0}) If (feature 2 <= 3.5) If (feature 3 <= 2.5) Predict: 1.0 Else (feature 3 > 2.5) If (feature 4 <= 1.5) Predict: 0.0 Else (feature 4 > 1.5) If (feature 3 <= 4.5) Predict: 1.0 Else (feature 3 > 4.5) Predict: 0.0 Else (feature 2 > 3.5) If (feature 0 <= 1.5) If (feature 5 <= 26.125) Predict: 0.0 Else (feature 5 > 26.125) If (feature 5 <= 26.46875) Predict: 1.0 Else (feature 5 > 26.46875) Predict: 0.0 Else (feature 0 > 1.5) If (feature 2 <= 15.5) If (feature 3 <= 1.5) Predict: 1.0 Else (feature 3 > 1.5) Predict: 0.0 Else (feature 2 > 15.5) Predict: 0.0 Else (feature 1 not in {0.0}) If (feature 0 <= 2.5) If (feature 2 <= 3.5) If (feature 0 <= 1.5) Predict: 0.0 Else (feature 0 > 1.5) Predict: 1.0 Else (feature 2 > 3.5) Predict: 1.0 Else (feature 0 > 2.5) If (feature 5 <= 24.808349999999997) If (feature 6 in {1.0,2.0}) If (feature 2 <= 30.25) Predict: 1.0 Else (feature 2 > 30.25) Predict: 0.0 Else (feature 6 not in {1.0,2.0}) If (feature 5 <= 21.0375) Predict: 1.0 Else (feature 5 > 21.0375) Predict: 0.0 Else (feature 5 > 24.808349999999997) Predict: 0.0 +----------+--------+--------------------+ |prediction|Survived| features| +----------+--------+--------------------+ | 0.0| 0|[3.0,0.0,22.0,1.0...| | 1.0| 1|[1.0,1.0,38.0,1.0...| | 1.0| 1|[3.0,1.0,26.0,0.0...| | 1.0| 1|[1.0,1.0,35.0,1.0...| | 0.0| 0|(7,[0,2,5],[3.0,3...| | 0.0| 0|[3.0,0.0,30.0,0.0...| | 0.0| 0|(7,[0,2,5],[1.0,5...| | 0.0| 0|[3.0,0.0,2.0,3.0,...| | 1.0| 1|[3.0,1.0,27.0,0.0...| | 1.0| 1|[2.0,1.0,14.0,1.0...| | 1.0| 1|[3.0,1.0,4.0,1.0,...| | 1.0| 1|[1.0,1.0,58.0,0.0...| | 0.0| 0|(7,[0,2,5],[3.0,2...| | 0.0| 0|[3.0,0.0,39.0,1.0...| | 1.0| 0|[3.0,1.0,14.0,0.0...| | 1.0| 1|[2.0,1.0,55.0,0.0...| | 
0.0| 0|[3.0,0.0,2.0,4.0,...| | 0.0| 1|(7,[0,2,5],[2.0,3...| | 1.0| 0|[3.0,1.0,31.0,1.0...| | 1.0| 1|[3.0,1.0,30.0,0.0...| +----------+--------+--------------------+ only showing top 20 rows Decision Tree Accu: 0.8417508417508418 查看预测结果 +----------+--------------------+ |prediction| features| +----------+--------------------+ | 1.0|[3.0,1.0,30.0,0.0...| +----------+--------------------+
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
· 字符编码:从基础到乱码解决
· 提示词工程——AI应用必不可少的技术