本地测试Spark的svm算法
上一篇介绍了逻辑回归算法,当时发现分类效果不好;这次测试SVM时才发现,问题其实出在训练数据上。于是从网上找了一部分训练数据,换用之后发现分类效果实际上还可以。
训练数据如下(保存到文件时每行一条记录):每条记录中"|"前面的值是标签,"|"后面以逗号分隔的几个值是某种花的相关特征。
1|5.1,3.5,1.4,0.2 1|4.9,3,1.4,0.2 1|4.7,3.2,1.3,0.2 1|4.6,3.1,1.5,0.2 1|5,3.6,1.4,0.2 1|5.4,3.9,1.7,0.4 1|4.6,3.4,1.4,0.3 1|5,3.4,1.5,0.2 1|4.4,2.9,1.4,0.2 1|4.9,3.1,1.5,0.1 1|5.4,3.7,1.5,0.2 1|4.8,3.4,1.6,0.2 1|4.8,3,1.4,0.1 1|4.3,3,1.1,0.1 1|5.8,4,1.2,0.2 1|5.7,4.4,1.5,0.4 1|5.4,3.9,1.3,0.4 1|5.1,3.5,1.4,0.3 1|5.7,3.8,1.7,0.3 1|5.1,3.8,1.5,0.3 1|5.4,3.4,1.7,0.2 1|5.1,3.7,1.5,0.4 1|4.6,3.6,1,0.2 1|5.1,3.3,1.7,0.5 1|4.8,3.4,1.9,0.2 0|7,3.2,4.7,1.4 0|6.4,3.2,4.5,1.5 0|6.9,3.1,4.9,1.5 0|5.5,2.3,4,1.3 0|6.5,2.8,4.6,1.5 0|5.7,2.8,4.5,1.3 0|6.3,3.3,4.7,1.6 0|4.9,2.4,3.3,1 0|6.6,2.9,4.6,1.3 0|5.2,2.7,3.9,1.4 0|5,2,3.5,1 0|5.9,3,4.2,1.5 0|6,2.2,4,1 0|6.1,2.9,4.7,1.4 0|5.6,2.9,3.6,1.3
测试数据如下。
0|5.1,2.5,3,1.1 0|5.7,2.8,4.1,1.3 1|5,3,1.6,0.2 1|5,3.4,1.6,0.4
svm代码跟逻辑回归类似,只需替换算法即可。
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Local smoke test for Spark MLlib's linear SVM (`SVMWithSGD`).
 *
 * Reads a text file with one sample per line in the form `label|f1,f2,...`,
 * trains an L1-regularised SVM, predicts one hand-written sample and prints
 * the area under the ROC curve measured on the training set.
 */
object TestSvmAlgorithm {

  /** Input file used when no path is passed on the command line. */
  private val DefaultDataPath = "file:///D:\\var\\11.txt"

  /** Number of SGD iterations used for training. */
  private val NumIterations = 10

  /** Regularisation strength passed to the optimizer. */
  private val RegParam = 0.1

  /**
   * Parses one `label|f1,f2,...` record into a [[LabeledPoint]].
   *
   * @param line a non-empty input line
   * @return the labelled feature vector described by `line`
   * @throws IllegalArgumentException if the line has no `|`-separated feature part
   * @throws NumberFormatException    if the label or a feature is not a number
   */
  private def parseLabeledPoint(line: String): LabeledPoint =
    line.split("\\|") match {
      // Extra `|`-separated fields are ignored, exactly as the original indexing did.
      case Array(label, features, _*) =>
        LabeledPoint(label.toDouble, Vectors.dense(features.split(",").map(_.toDouble)))
      case _ =>
        throw new IllegalArgumentException(s"Malformed record, expected 'label|f1,f2,...': '$line'")
    }

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the training file path
   *             (defaults to the original hard-coded location)
   */
  def main(args: Array[String]): Unit = {
    val dataPath = args.headOption.getOrElse(DefaultDataPath)

    val sparkConf = new SparkConf()
      .setAppName("svm")
      .setMaster("local")
      .set("spark.testing.memory", "2147480000")
    val sparkContext = new SparkContext(sparkConf)

    // Always release the context, even when parsing or training fails.
    try {
      val trainData = sparkContext
        .textFile(dataPath)
        .map(_.trim)
        .filter(_.nonEmpty) // tolerate blank / trailing empty lines
        .map { line =>
          // Parse once and reuse the result for both the debug print and the RDD element.
          val point = parseLabeledPoint(line)
          println(("数据:" + point.label, point.features))
          point
        }
        .cache()

      val svm = new SVMWithSGD()
      svm.optimizer
        .setNumIterations(NumIterations)
        .setRegParam(RegParam)
        .setUpdater(new L1Updater())
      val model = svm.run(trainData)

      // Classify a single hand-written sample with the default threshold (prints 0.0 or 1.0).
      val predictData = Vectors.dense(6.6, 3.0, 4.4, 1.4)
      println(predictData)
      println(model.predict(predictData))

      // A second model with the same parameters but no threshold yields raw scores.
      // ROC/AUC needs raw scores; thresholded 0/1 predictions give only one operating point.
      val scorer = new SVMModel(model.weights, model.intercept)
      scorer.clearThreshold()

      // BinaryClassificationMetrics expects (score, label) pairs, in that order.
      val scoreAndLabels = trainData.map { p =>
        val predi = model.predict(p.features)
        println("预测" + (p.label, predi))
        (scorer.predict(p.features), p.label)
      }

      // NOTE: this is the AUC on the training set, not on held-out data.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auRoc = metrics.areaUnderROC()
      println(":" + auRoc)
    } finally {
      sparkContext.stop()
    }
  }
}