在 Spark mapPartitions 中使用迭代器进行优化

使用 mapPartitions 时,往往会在每个分区内创建文件或数据库连接等资源。常见写法是再创建一个容器(如 ListBuffer),用于存储维表关联后的数据。但这有一个问题:整个分区的结果都会缓存在这个容器中,分区较大时会占用大量内存。这时可以使用迭代器进行优化:每次调用 next() 时才处理一条数据并返回,因此无需在内存中保存整个分区的结果。

 

一、常规写法

package org.shydow

import java.sql.{Connection, PreparedStatement}

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.shydow.DBPool.MySQLPoolManager

import scala.collection.mutable.ListBuffer

/**
 * Baseline ("common") approach to a dimension-table lookup inside `mapPartitions`.
 *
 * Per partition it borrows one MySQL connection and looks up each line's `event_name` in
 * `dim_event_info`. Every enriched [[Event]] is buffered in a `ListBuffer` before the partition
 * is handed back, so the whole partition's output is held in memory at once.
 *
 * @author shydow
 * @date 2021-12-13
 */
object TestMapPartition {

  case class Event(eventId: String, eventName: String, pv: Long, stayTime: String)

  def main(args: Array[String]): Unit = {

    val conf: SparkConf = new SparkConf().setAppName("test-mapPartition").setMaster("local[*]")
    val sc = new SparkContext(conf)

    val lines: RDD[String] = sc.textFile("/app/event_log.txt", 4)
    // NOTE(review): mapPartitions is lazy and nothing consumes its result. As written, no action
    // runs and the lookup never executes; attach an action (e.g. .count()) if the work should happen.
    lines.mapPartitions { it =>
      val conn: Connection = MySQLPoolManager.getMySQLManager.getConnection
      try {
        val ps: PreparedStatement =
          conn.prepareStatement("select event_name from dim_event_info where event_id = ?")
        try {
          val list: ListBuffer[Event] = ListBuffer[Event]()
          while (it.hasNext) {
            // Assumed line layout: eventId,<unused>,pv,stayTime (fields 0, 2 and 3 are read).
            val arr: Array[String] = it.next().split(",")
            list.append(Event(arr(0), lookupEventName(ps, arr(0)), arr(2).toLong, arr(3)))
          }
          // The partition is fully materialised at this point, so releasing the
          // statement/connection in the finally blocks below is safe.
          list.iterator
        } finally {
          ps.close()
        }
      } finally {
        // Release the connection (for a pooled connection this hands it back to the pool);
        // previously it was never closed and leaked once per partition.
        conn.close()
      }
    }

    sc.stop()
  }

  /**
   * Runs the prepared dimension lookup for a single event id and closes its result set.
   *
   * @param ps      statement of the form `... where event_id = ?`
   * @param eventId value bound to the single placeholder
   * @return `event_name` of the last matching row, or null when no row matches
   */
  private def lookupEventName(ps: PreparedStatement, eventId: String): String = {
    ps.setString(1, eventId)
    val res = ps.executeQuery()
    try {
      var eventName: String = null
      while (res.next()) {
        eventName = res.getString("event_name")
      }
      eventName
    } finally {
      res.close()
    }
  }
}

 

 

二、使用迭代器进行优化

package org.shydow

import java.sql.{Connection, PreparedStatement, ResultSet}

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.shydow.DBPool.MySQLPoolManager


/**
 * Iterator-based variant of the dimension lookup demo.
 *
 * Each partition is wrapped in a [[LookupEventIter]], which enriches lines one at a time as they
 * are pulled instead of buffering the whole partition's output in a collection.
 *
 * @author shydow
 * @date 2021-12-13
 */
object TestMapPartition {

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf()
        .setAppName("test-mapPartition")
        .setMaster("local[*]")
    )

    // Second argument is the minimum number of partitions requested from textFile.
    val eventLines: RDD[String] = sc.textFile("/app/event_log.txt", 4)

    // NOTE(review): mapPartitions is lazy and nothing consumes its result. As written, no action
    // runs and the lookup never executes; attach an action (e.g. .count()) if the work should happen.
    eventLines.mapPartitions(partition => new LookupEventIter(partition))

    sc.stop()
  }
}

/**
 * One event-log record enriched with its dimension-table name.
 *
 * @param eventId   event identifier (first comma-separated field of the log line)
 * @param eventName `event_name` looked up from `dim_event_info`; may be null when nothing matched
 * @param pv        pv count, parsed from the third field (presumably page views)
 * @param stayTime  stay time kept as raw text (fourth field)
 */
case class Event(
    eventId: String,
    eventName: String,
    pv: Long,
    stayTime: String
)

/**
 * Iterator adapter that enriches each raw line of a partition with `event_name` from
 * `dim_event_info`, one row per `next()` call, so the partition is never buffered in memory.
 *
 * One MySQL connection and prepared statement are opened per instance (per partition when used
 * from `mapPartitions`). They are released exactly once, in one of two ways:
 *  - eagerly, when `hasNext` first reports the upstream iterator is exhausted;
 *  - via a Spark task-completion listener, when the consumer stops early (e.g. `take`) or the
 *    task fails, since `hasNext` then never returns false.
 *
 * @param it upstream partition lines; assumed layout `eventId,<unused>,pv,stayTime`
 */
class LookupEventIter(it: Iterator[String]) extends Iterator[Event] {

  private val conn: Connection = MySQLPoolManager.getMySQLManager.getConnection
  private val ps: PreparedStatement = conn.prepareStatement("select event_name from dim_event_info where event_id = ?")
  private var closed: Boolean = false

  // Safety net for early termination. TaskContext.get() is null when not running inside a Spark
  // task (e.g. in a local unit test), in which case only the eager close in hasNext applies.
  // An explicit TaskCompletionListener avoids overload ambiguity across Spark/Scala versions.
  Option(org.apache.spark.TaskContext.get()).foreach { ctx =>
    ctx.addTaskCompletionListener(new org.apache.spark.util.TaskCompletionListener {
      override def onTaskCompletion(context: org.apache.spark.TaskContext): Unit = close()
    })
  }

  override def hasNext: Boolean = {
    val more = it.hasNext
    // Release resources as soon as the partition is drained; repeated calls are harmless.
    if (!more) close()
    more
  }

  override def next(): Event = {
    val arr: Array[String] = it.next().split(",")
    ps.setString(1, arr(0))
    val res: ResultSet = ps.executeQuery()
    // Keeps the last matching row's name (null if none); the result set is closed per row,
    // where previously one was leaked for every record.
    val eventName: String =
      try {
        var name: String = null
        while (res.next()) {
          name = res.getString("event_name")
        }
        name
      } finally {
        res.close()
      }
    Event(arr(0), eventName, arr(2).toLong, arr(3))
  }

  /** Closes the statement, then the connection; subsequent calls are no-ops. */
  private def close(): Unit = {
    if (!closed) {
      closed = true
      try ps.close()
      finally conn.close()
    }
  }
}

 

三、数据库连接池(使用 c3p0)

package org.shydow.DBPool

import java.sql.Connection

/**
 * Serializable wrapper around a c3p0 `ComboPooledDataSource`.
 *
 * The data source is configured from `Constants`: JDBC URL, driver class, credentials,
 * acquire increment, min/max pool size and max cached statements.
 *
 * Construction fails fast. If any setting is rejected, the data source is closed and the error
 * is rethrown, rather than silently handing out a partially configured pool.
 *
 * @author shydow
 * @date 2021-10-09
 */
class MySQLPool extends Serializable {

  // Fully qualified because this snippet does not import c3p0; `true` is the autoregister flag.
  private val cpd = new com.mchange.v2.c3p0.ComboPooledDataSource(true)
  try {
    cpd.setJdbcUrl(Constants.MYSQL_URL)
    cpd.setDriverClass(Constants.MYSQL_DRIVER)
    cpd.setUser(Constants.MYSQL_USER)
    cpd.setPassword(Constants.MYSQL_PASSWORD)
    cpd.setAcquireIncrement(Constants.MYSQL_AC)
    cpd.setMinPoolSize(Constants.MYSQL_MINPS)
    cpd.setMaxPoolSize(Constants.MYSQL_MAXPS)
    cpd.setMaxStatements(Constants.MYSQL_MAXST)
  } catch {
    case scala.util.control.NonFatal(e) =>
      // Release the half-configured data source before propagating the configuration error.
      cpd.close()
      throw e
  }

  /**
   * Borrows a connection from the pool. Callers should `close()` it when done, which, for a
   * pooled connection, hands it back to the pool.
   *
   * The underlying exception is propagated rather than converted to `null`. Callers dereference
   * the result immediately, so a `null` would only surface later as an uninformative
   * NullPointerException.
   *
   * @throws java.sql.SQLException if no connection can be obtained
   */
  def getConnection: Connection = cpd.getConnection()

  /** Shuts down the pool; best-effort, so a failure is only printed. */
  def close(): Unit = {
    try {
      cpd.close()
    } catch {
      case scala.util.control.NonFatal(e) => e.printStackTrace()
    }
  }
}
package org.shydow.DBPool

/**
 * Holder for a single shared [[MySQLPool]], created on first request.
 *
 * @author shydow
 * @date 2021-10-09
 */
object MySQLPoolManager {
  var mm: MySQLPool = _

  /** Returns the shared pool, constructing it on first use while holding this object's monitor. */
  def getMySQLManager: MySQLPool = synchronized {
    if (mm eq null) mm = new MySQLPool
    mm
  }
}

 

 

posted @ 2021-12-13 11:29  Shydow  阅读(412)  评论(0)  编辑  收藏  举报