Spark机器学习(1):线性回归算法

线性回归算法,是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。

1. 梯度下降法

线性回归可以使用最小二乘法,但是速度比较慢,因此一般使用梯度下降法(Gradient Descent),梯度下降法又分为批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent)。批量梯度下降法每次迭代需要使用训练集里面的所有数据,当训练集数据量较大时,速度就很慢;随机梯度下降法每次迭代只需要一个样本的数据,速度较快,对于大数据集,可能只需要使用少部分数据就达到收敛值,虽然有可能在最小值周围震荡,但是大多数情况下效果不错,所以,一般使用随机梯度下降法。

2. Mllib的线性回归

Mllib的线性回归采用的是随机梯度下降法。直接上代码:

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

object LinearRegression {

  def main(args: Array[String]) {
    // 设置运行环境
    val conf = new SparkConf().setAppName("Linear Regression Test").setMaster("spark://master:7077").setJars(Seq("E:\\Intellij\\Projects\\MachineLearning\\MachineLearning.jar"))
    val sc = new SparkContext(conf)
    Logger.getRootLogger.setLevel(Level.WARN)

    //读取样本数据,生成RDD
    val data_path = "hdfs://master:9000/ml/data/lpsa.data"
    val dataRDD = sc.textFile(data_path)
    val examples = dataRDD.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()// 迭代次数
    val numIterations = 100
    // 步长
    val stepSize = 0.5
    // 选取样本的比例
    val miniBatchFraction = 1.0
    // 用随机梯度下降模型训练
    val sgdModel = LinearRegressionWithSGD.train(examples, numIterations, stepSize, miniBatchFraction)

    // 对样本进行测试
    val prediction = sgdModel.predict(examples.map(_.features))
    val predictionAndLabel = prediction.zip(examples.map(_.label))
    // 选取前100个样本
    val show_predict = predictionAndLabel.take(100)
    println("Prediction" + "\t" + "Label" + "\t" + "Diff")
    for (i <- 0 to show_predict.length - 1) {
      val diff = show_predict(i)._1-show_predict(i)._2
      println(show_predict(i)._1 + "\t" + show_predict(i)._2 + "\t" + diff)
    }

  }

}

部分运行结果:

 

posted @ 2017-06-13 18:33  MSTK  阅读(2369)  评论(0编辑  收藏  举报