Spark API for RDD-based Decision Trees
Overview
Decision trees and their ensembles are popular methods for classification and regression tasks in machine learning. Decision trees are widely used because they are easy to interpret, handle categorical features, extend to the multiclass setting, require no feature scaling, and can capture non-linearities and feature interactions. Tree ensemble algorithms such as random forests and boosting are among the top performers for classification and regression tasks.
This post won't go into the details of how decision trees work; it focuses on Spark's decision tree API, and in particular on techniques for tuning its parameters.
Usage Tips
Before walking through the usage tips, let's first state the rules under which a decision tree stops creating new nodes (a configuration sketch follows the list):
Stopping rules
- The current node has reached the maximum training depth (maxDepth).
- No candidate split over the selected features yields an information gain greater than the minimum gain (minInfoGain).
- No split candidates remain (for example, all instances at the node share the same label).
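These stopping conditions map directly onto fields of MLlib's Strategy configuration object. A minimal sketch, assuming a trainData RDD[LabeledPoint] like the one prepared in the examples below:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Strategy

val strategy = Strategy.defaultStrategy("Classification")
strategy.numClasses = 2
strategy.maxDepth = 5            //rule 1: a node at this depth is never split further
strategy.minInfoGain = 0.01      //rule 2: splits with a smaller gain are rejected
strategy.minInstancesPerNode = 2 //related guard: each child must receive at least this many instances
val model = DecisionTree.train(trainData, strategy)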
With the stopping rules covered, the following guidelines for using decision trees are listed in descending order of importance.
The parameters below describe the problem you are solving and your dataset. They must be specified but do not need to be tuned.
- algo: the type of decision tree, either classification or regression.
- numClasses: the number of classes (classification only).
- categoricalFeaturesInfo: specifies which features are categorical and how many categorical values each of those features can take. This is a map from feature index to feature arity (the number of categories); any feature not in the map is treated as continuous. A short example follows this list.
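For instance, a hypothetical dataset whose first feature is binary and whose fifth feature takes ten category values would be described as:

val categoricalFeaturesInfo = Map(0 -> 2, 4 -> 10) //feature 0 has 2 categories, feature 4 has 10; all others are continuous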
Tunable parameters
The following parameters can be tuned. When tuning, be careful to validate on held-out test data to avoid overfitting; a minimal tuning loop is sketched below.
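A minimal tuning sketch, assuming the trainData/testData split and imports from the classification example below. It sweeps maxDepth and keeps the value with the lowest held-out error (with more data, tune on a separate validation split rather than the final test set):

val candidateDepths = Seq(2, 4, 6, 8)
val (bestDepth, bestErr) = candidateDepths.map { depth =>
  //same arguments as the classification example below, varying only the depth
  val m = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 3, "auto", "gini", depth, 32)
  val err = testData.map(p => (p.label, m.predict(p.features)))
    .filter(r => r._1 != r._2).count().toDouble / testData.count()
  (depth, err)
}.minBy(_._2)
println(s"Best maxDepth = $bestDepth, held-out error = $bestErr")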
Classification (random forest)
package com.ForestTest.com

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object Main {
  def main(args: Array[String]): Unit = {
    //TODO Set up the environment
    val conf: SparkConf = new SparkConf().setAppName("AppName").setMaster("local[*]")
    val sc = new SparkContext(conf)
    //TODO Data handling
    //Load the data
    val path = "src/main/resources/data/mllib/sample_libsvm_data.txt"
    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, path)
    val dataArr: Array[RDD[LabeledPoint]] = data.randomSplit(Array(0.7, 0.3))
    val (trainData, testData): (RDD[LabeledPoint], RDD[LabeledPoint]) = (dataArr(0), dataArr(1))
    //Train the model
    val numClasses = 2
    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]() //an empty map treats every feature as continuous; add entries for categorical features
    val numTrees = 3
    val featureSubsetStrategy = "auto"
    val impurity = "gini"
    val maxDepth = 4
    val maxBins = 32
    val model: RandomForestModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy,
      impurity, maxDepth, maxBins)
    val labelAndPreds: RDD[(Double, Double)] = testData.map(point => {
      val prediction: Double = model.predict(point.features)
      (point.label, prediction)
    })
    val testErr: Double = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
    //Output the results
    println(s"Test Error = $testErr")
    println(s"Learned classification forest model:\n ${model.toDebugString}")
    //TODO Tear down the environment
    sc.stop()
  }
}
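A trained forest can also be persisted and re-loaded later; a small sketch, to be placed before sc.stop() in the example above (the path is an illustrative placeholder):

model.save(sc, "target/tmp/myRandomForestClassificationModel")
val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel")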
Regression (random forest)
package com.RegressionForestTest

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object Main {
  def main(args: Array[String]): Unit = {
    //TODO Set up the environment
    val conf: SparkConf = new SparkConf().setAppName("test").setMaster("local[*]")
    val sc = new SparkContext(conf)
    //TODO Data handling
    //Load the data
    val path = "src/main/resources/data/mllib/sample_libsvm_data.txt"
    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, path)
    val dataArr: Array[RDD[LabeledPoint]] = data.randomSplit(Array(0.7, 0.3))
    //Train the model
    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()
    val numTrees = 3
    val featureSubsetStrategy = "auto"
    val impurity = "variance" //regression uses variance as the impurity measure
    val maxDepth = 4
    val maxBins = 32
    val model: RandomForestModel = RandomForest.trainRegressor(dataArr(0), categoricalFeaturesInfo, numTrees,
      featureSubsetStrategy, impurity, maxDepth, maxBins)
    val labelAndPredictions: RDD[(Double, Double)] = dataArr(1).map(point => {
      val prediction: Double = model.predict(point.features)
      (point.label, prediction)
    })
    val testMSE: Double = labelAndPredictions.map(r => math.pow((r._1 - r._2), 2)).mean()
    println(s"Test Mean Squared Error = ${testMSE}")
    println(s"Learned regression forest model:\n ${model.toDebugString}")
    //TODO Tear down the environment
    sc.stop()
  }
}
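Instead of hand-rolling the MSE, MLlib's RegressionMetrics computes several error measures at once from (prediction, label) pairs; a sketch reusing labelAndPredictions from the example above:

import org.apache.spark.mllib.evaluation.RegressionMetrics

val metrics = new RegressionMetrics(labelAndPredictions.map(_.swap)) //RegressionMetrics expects (prediction, label)
println(s"RMSE = ${metrics.rootMeanSquaredError}, R2 = ${metrics.r2}")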
Gradient-boosted trees (GBTs)
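GBTs train one tree at a time, each new tree correcting the errors of the current ensemble, so the key parameters live on BoostingStrategy. Both examples below start from BoostingStrategy.defaultParams and override only a few fields; here is a sketch of the fields you will most often adjust (the values are illustrative, not recommendations):

import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.loss.LogLoss

val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 50  //number of trees in the ensemble
boostingStrategy.learningRate = 0.05 //shrinks each tree's contribution; smaller values are more robust but need more iterations
boostingStrategy.loss = LogLoss      //the loss minimized at each boosting iteration
boostingStrategy.treeStrategy.maxDepth = 3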
Boosted trees (classification)
package com.GradientBoostTest

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object Main {
  def main(args: Array[String]): Unit = {
    //TODO Set up the environment
    val conf: SparkConf = new SparkConf().setAppName("test").setMaster("local[*]")
    val sc = new SparkContext(conf)
    //TODO Data handling
    val path = "src/main/resources/data/mllib/sample_libsvm_data.txt"
    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, path)
    val dataArr: Array[RDD[LabeledPoint]] = data.randomSplit(Array(0.7, 0.3))
    //Set the boosting parameters
    val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 3
    boostingStrategy.treeStrategy.numClasses = 2
    boostingStrategy.treeStrategy.maxDepth = 5
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() //an empty map treats every feature as continuous
    val model: GradientBoostedTreesModel = GradientBoostedTrees.train(dataArr(0), boostingStrategy)
    //Compute predictions on the held-out split
    val labelAndPreds: RDD[(Double, Double)] = dataArr(1).map(point => {
      val prediction: Double = model.predict(point.features)
      (point.label, prediction)
    })
    val testErr: Double = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / dataArr(1).count()
    println(s"Test Error = $testErr")
    println(s"Learned classification GBT model:\n ${model.toDebugString}")
    //TODO Tear down the environment
    sc.stop()
  }
}
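Since boosting can overfit as the number of iterations grows, the RDD API also offers validation-based early stopping via runWithValidation; a sketch assuming the data RDD and boostingStrategy from the example above:

val Array(train, validation) = data.randomSplit(Array(0.8, 0.2))
//stops adding trees once performance on the validation split stops improving
val validatedModel = new GradientBoostedTrees(boostingStrategy).runWithValidation(train, validation)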
Boosted trees (regression)
package com.GradientBoost

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object Main {
  def main(args: Array[String]): Unit = {
    //TODO Set up the environment
    val conf: SparkConf = new SparkConf().setAppName("test").setMaster("local[*]")
    val sc = new SparkContext(conf)
    //TODO Data handling
    //Load the data
    val path = "src/main/resources/data/mllib/sample_libsvm_data.txt"
    val data: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, path)
    val dataArr: Array[RDD[LabeledPoint]] = data.randomSplit(Array(0.7, 0.3))
    //Set the boosting parameters
    val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 3
    boostingStrategy.treeStrategy.maxDepth = 5
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() //an empty map treats every feature as continuous
    val model: GradientBoostedTreesModel = GradientBoostedTrees.train(dataArr(0), boostingStrategy)
    //Compute predictions on the held-out split
    val labelsAndPrediction: RDD[(Double, Double)] = dataArr(1).map(point => {
      val prediction: Double = model.predict(point.features)
      (point.label, prediction)
    })
    //Compute the error
    val testMSE: Double = labelsAndPrediction.map(r => math.pow((r._1 - r._2), 2)).mean()
    println(s"Test Mean Squared Error = ${testMSE}")
    println(s"Learned regression GBT model:\n ${model.toDebugString}")
    //TODO Tear down the environment
    sc.stop()
  }
}
Author: ALINGMAOMAO
Source: https://www.cnblogs.com/ALINGMAOMAO/p/17118147.html
License: This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license.