使用Spark读取MySQL、Hive数据并写入MySQL数据库

最近的一个项目需要对近70亿条数据进行统计分析。如果将数据存入MySQL,读取如此大量的数据非常困难;即使借助搜索引擎,查询速度也非常慢。经过调研,我们决定借助公司的大数据平台,结合Spark技术完成如此大数据量的统计分析。

为了方便后续开发,我们编写了几个工具类,将对MySQL及Hive的操作代码封装起来,使开发人员只需关心业务代码的编写。

工具类如下:

一. Spark操作MySQL

1. 根据sql语句获取Spark DataFrame:

  /**
   * Runs a SQL query against the configured MySQL database and exposes the result as a DataFrame.
   *
   * Connection settings (url, user, password) are read from `mySqlConfig`; the query text is
   * handed to Spark's JDBC data source through its `query` option.
   *
   * @param spark the active SparkSession
   * @param sql   the query to run on MySQL
   * @return a DataFrame with the query result
   */
  def getDFFromMysql(spark: SparkSession, sql: String): DataFrame = {
    println(s"url:${mySqlConfig.url} user:${mySqlConfig.user} sql: ${sql}")
    val jdbcOptions = Map(
      "url" -> mySqlConfig.url,
      "user" -> mySqlConfig.user,
      "password" -> mySqlConfig.password,
      "driver" -> "com.mysql.jdbc.Driver",
      "query" -> sql
    )
    spark.read
      .format("jdbc")
      .options(jdbcOptions)
      .load()
  }

 

2. 将Spark DataFrame 写入MySQL数据库表

  /**
   * Writes a DataFrame into a MySQL table with the requested save mode.
   *
   * Only `SaveMode.Append` and `SaveMode.Overwrite` are supported; any other mode raises an
   * exception.
   *
   * @param df        the rows to write
   * @param mode      the save mode: Append or Overwrite
   * @param tableName target table name, e.g. database_name.table_name
   */
  def writeIntoMySql(df: DataFrame, mode: SaveMode, tableName: String): Unit = mode match {
    case SaveMode.Append    => appendDataIntoMysql(df, tableName)
    case SaveMode.Overwrite => overwriteMysqlData(df, tableName)
    case _                  => throw new Exception("目前只支持Append及Overwrite!")
  }
  /**
   * Appends every row of a DataFrame to a MySQL table via Spark's JDBC writer.
   *
   * @param df             the rows to append
   * @param mysqlTableName target table, e.g. database_name.table_name
   */
  def appendDataIntoMysql(df: DataFrame, mysqlTableName: String): Unit = {
    val appendWriter = df.write.mode(SaveMode.Append)
    appendWriter.jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
  }
  /**
   * Replaces the contents of a MySQL table with the rows of a DataFrame.
   *
   * The table is truncated first and the rows are then appended. This presumably keeps the
   * existing table definition instead of letting Spark's Overwrite mode recreate it. The two
   * steps are not atomic, and the result of the truncate call is not checked here.
   *
   * @param df             the new table contents
   * @param mysqlTableName target table, e.g. database_name.table_name
   */
  def overwriteMysqlData(df: DataFrame, mysqlTableName: String): Unit = {
    // Step 1: empty the table
    truncateMysqlTable(mysqlTableName)
    // Step 2: append the new rows to the emptied table
    appendDataIntoMysql(df, mysqlTableName)
  }
  /**
   * Removes all rows from a MySQL table with `TRUNCATE TABLE`.
   *
   * Errors are logged and reported through the return value rather than thrown.
   *
   * @param mysqlTableName table to truncate, e.g. database_name.table_name. It is interpolated
   *                       into the statement verbatim, so it must come from trusted code.
   * @return true if the statement ran successfully, false if it failed
   */
  def truncateMysqlTable(mysqlTableName: String): Boolean = {
    val conn = MySQLPoolManager.getMysqlManager.getConnection // take a connection from the pool
    try {
      val statement = conn.createStatement()
      try {
        // Statement.execute returns false for statements that produce no ResultSet (such as
        // TRUNCATE), so its result cannot serve as the success flag.
        statement.execute(s"truncate table $mysqlTableName")
        true
      } finally {
        statement.close()
      }
    } catch {
      case e: Exception =>
        println(s"mysql truncateMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
        false
    } finally {
      conn.close() // presumably hands the connection back to the pool
    }
  }

 3. 根据条件删除MySQL表数据

  /**
   * Deletes the rows of a MySQL table that match a WHERE condition.
   *
   * Errors are logged and reported through the return value rather than thrown.
   *
   * @param mysqlTableName table to delete from, e.g. database_name.table_name
   * @param condition      SQL boolean expression used as the WHERE clause, e.g. `id > 100`.
   *                       Both arguments are interpolated into the statement verbatim, so they
   *                       must never contain untrusted input.
   * @return true if the statement ran successfully (even when no row matched), false if it failed
   */
  def deleteMysqlTableData(mysqlTableName: String, condition: String): Boolean = {
    val conn = MySQLPoolManager.getMysqlManager.getConnection // take a connection from the pool
    try {
      val statement = conn.createStatement()
      try {
        // executeUpdate is the JDBC call for DML; Statement.execute would return false even
        // on success, because a DELETE produces no ResultSet.
        statement.executeUpdate(s"delete from $mysqlTableName where $condition")
        true
      } finally {
        statement.close()
      }
    } catch {
      case e: Exception =>
        println(s"mysql deleteMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
        false
    } finally {
      conn.close() // presumably hands the connection back to the pool
    }
  }

4. 保存DataFrame 到 MySQL中,如果表不存在的话,会自动创建

/**
   * Saves a DataFrame into a MySQL table, creating the table from the DataFrame schema first
   * when it does not exist yet.
   *
   * The steps run in order:
   *  1. create the table if it is missing,
   *  2. check that the table's columns match the DataFrame's in count, name and order,
   *  3. insert the rows over pooled connections.
   *
   * @param tableName       target MySQL table name
   * @param resultDateFrame the rows to save
   */
  def saveDFtoDBCreateTableIfNotExist(tableName: String, resultDateFrame: DataFrame): Unit = {
    createTableIfNotExist(tableName, resultDateFrame)   // step 1
    verifyFieldConsistency(tableName, resultDateFrame)  // step 2: throws on any mismatch
    saveDFtoDBUsePool(tableName, resultDateFrame)       // step 3
  }
  /**
    * Creates a MySQL table from a DataFrame schema when no table with that name is found.
    * Columns are created in the same order as the DataFrame's fields.
    *
    * A field named `id` (case-insensitive) becomes an `int` AUTO_INCREMENT primary key. Other
    * fields are mapped by Spark type: Byte/Short/Integer -> int, Long -> bigint,
    * Boolean -> tinyint, Float -> float, Double -> double, String -> varchar(50),
    * Timestamp -> timestamp (default current_timestamp), Date -> date. Any other type raises
    * a RuntimeException.
    *
    * NOTE(review): existence is checked with the schema pattern "%", so a table with the same
    * name in another database may be treated as existing. Confirm this against the driver's
    * catalog settings.
    *
    * @param tableName table name, used as a back-quoted identifier
    * @param df        DataFrame whose schema defines the columns
    * @return no meaningful value (declared AnyVal for compatibility)
    */
  def createTableIfNotExist(tableName: String, df: DataFrame): AnyVal = {
    val con = MySQLPoolManager.getMysqlManager.getConnection
    try {
      val colResultSet = con.getMetaData.getColumns(null, "%", tableName, "%")
      // Any column row in the metadata means the table already exists
      val tableExists = try colResultSet.next() finally colResultSet.close()
      if (!tableExists) {
        val columnDefinitions = df.schema.fields.map { field =>
          val name = field.name
          if (name.equalsIgnoreCase("id")) {
            s"`$name` int(255) NOT NULL AUTO_INCREMENT PRIMARY KEY" // integer, auto-increment primary key
          } else {
            field.dataType match {
              case _: ByteType => s"`$name` int(100) DEFAULT NULL"
              case _: ShortType => s"`$name` int(100) DEFAULT NULL"
              case _: IntegerType => s"`$name` int(100) DEFAULT NULL"
              case _: LongType => s"`$name` bigint(100) DEFAULT NULL"
              case _: BooleanType => s"`$name` tinyint DEFAULT NULL"
              case _: FloatType => s"`$name` float(50) DEFAULT NULL"
              // MySQL only accepts DOUBLE or DOUBLE(M,D); DOUBLE(50) is a syntax error
              case _: DoubleType => s"`$name` double DEFAULT NULL"
              case _: StringType => s"`$name` varchar(50) DEFAULT NULL"
              case _: TimestampType => s"`$name` timestamp DEFAULT current_timestamp"
              case _: DateType => s"`$name` date  DEFAULT NULL"
              case _ => throw new RuntimeException(s"nonsupport ${field.dataType} !!!")
            }
          }
        }
        val sql_createTable = columnDefinitions.mkString(
          s"CREATE TABLE `$tableName` (", ",", ") ENGINE=InnoDB DEFAULT CHARSET=utf8")
        println(sql_createTable)
        val statement = con.createStatement()
        try {
          statement.execute(sql_createTable)
        } finally {
          statement.close()
        }
      }
    } finally {
      con.close()
    }
  }
  /**
    * Checks that the MySQL table has exactly the DataFrame's columns: the same count, and the
    * same names (case-sensitive) in the same order.
    *
    * The column names are read in one forward pass over the JDBC metadata, so no scrollable
    * ResultSet is required. The connection and ResultSet are always closed.
    *
    * @param tableName table name
    * @param df        DataFrame to compare against
    * @throws Exception if the column count or any column name differs
    */
  def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
    val con = MySQLPoolManager.getMysqlManager.getConnection
    val tableFieldNames: Vector[String] =
      try {
        val colResultSet = con.getMetaData.getColumns(null, "%", tableName, "%")
        try {
          val names = Vector.newBuilder[String]
          while (colResultSet.next()) {
            names += colResultSet.getString("COLUMN_NAME")
          }
          names.result()
        } finally {
          colResultSet.close()
        }
      } finally {
        con.close()
      }

    val tableFiledNum = tableFieldNames.length
    val dfFiledNum = df.columns.length
    if (tableFiledNum != dfFiledNum) {
      throw new Exception(s"数据表和DataFrame字段个数不一致!!table--$tableFiledNum but dataFrame--$dfFiledNum")
    }
    tableFieldNames.zip(df.columns).foreach { case (tableFileName, dfFiledName) =>
      if (!tableFileName.equals(dfFiledName)) {
        throw new Exception(s"数据表和DataFrame字段名不一致!!table--'$tableFileName' but dataFrame--'$dfFiledName'")
      }
    }
  }
/**
    * Inserts every row of a DataFrame into an existing MySQL table. On each executor, it uses
    * connections taken from the connection pool (c3p0, per the original design note).
    *
    * Each partition is written in its own transaction:
    *  - every value is bound with the JDBC setter that matches its Spark type,
    *  - rows are sent in JDBC batches of `batchSize`,
    *  - the transaction is committed once at the end of the partition.
    *
    * Null values are bound with `setNull`, using the SQL type from the target table's column
    * metadata at the same position. The table's column order must therefore match the
    * DataFrame's.
    *
    * A failure inside a partition is logged and that partition's transaction is rolled back.
    * The failure is not rethrown, so the Spark job itself does not fail.
    *
    * @param tableName       target MySQL table
    * @param resultDateFrame the rows to insert
    * @param batchSize       rows per JDBC batch (default 1000); must be positive
    */
  def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame, batchSize: Int = 1000): Unit = {
    require(batchSize > 0, s"batchSize must be positive, got $batchSize")
    val colNumbers = resultDateFrame.columns.length
    val sql = getInsertSql(tableName, colNumbers)
    val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
    resultDateFrame.foreachPartition(partitionRecords => {
      val conn = MySQLPoolManager.getMysqlManager.getConnection // take a connection from the pool
      try {
        val preparedStatement = conn.prepareStatement(sql)
        try {
          // Column metadata of the target table, used only to find the SQL type for setNull
          val metaData = conn.getMetaData.getColumns(null, "%", tableName, "%")
          try {
            conn.setAutoCommit(false)
            var pendingRows = 0
            partitionRecords.foreach(record => {
              // JDBC parameter indexes start at 1, Row indexes start at 0
              for (i <- 1 to colNumbers) {
                val idx = i - 1
                if (record.isNullAt(idx)) {
                  metaData.absolute(i)
                  preparedStatement.setNull(i, metaData.getInt("DATA_TYPE"))
                } else {
                  columnDataTypes(idx) match {
                    // Byte/Short values are boxed as java.lang.Byte/Short; getAs[Int] would throw
                    // ClassCastException, so read them with their own getters and widen to Int
                    case _: ByteType => preparedStatement.setInt(i, record.getByte(idx))
                    case _: ShortType => preparedStatement.setInt(i, record.getShort(idx))
                    case _: IntegerType => preparedStatement.setInt(i, record.getInt(idx))
                    case _: LongType => preparedStatement.setLong(i, record.getLong(idx))
                    case _: BooleanType => preparedStatement.setBoolean(i, record.getBoolean(idx))
                    case _: FloatType => preparedStatement.setFloat(i, record.getFloat(idx))
                    case _: DoubleType => preparedStatement.setDouble(i, record.getDouble(idx))
                    case _: StringType => preparedStatement.setString(i, record.getString(idx))
                    case _: TimestampType => preparedStatement.setTimestamp(i, record.getAs[Timestamp](idx))
                    case _: DateType => preparedStatement.setDate(i, record.getAs[Date](idx))
                    case other => throw new RuntimeException(s"nonsupport ${other} !!!")
                  }
                }
              }
              preparedStatement.addBatch()
              pendingRows += 1
              // Flush periodically so a large partition is not buffered in a single JDBC batch
              if (pendingRows >= batchSize) {
                preparedStatement.executeBatch()
                pendingRows = 0
              }
            })
            if (pendingRows > 0) preparedStatement.executeBatch()
            conn.commit()
          } catch {
            case e: Exception =>
              println(s"@@ saveDFtoDBUsePool error: ${ExceptionUtil.getExceptionStack(e)}")
              // Discard this partition's partial writes; best effort, since the cause is already logged
              try conn.rollback() catch { case _: java.sql.SQLException => () }
          } finally {
            metaData.close()
          }
        } finally {
          preparedStatement.close()
        }
      } finally {
        conn.close()
      }
    })
  }

 二、操作Spark

1. 切换Spark环境

定义环境Profile.scala

/**
 * The deployment environments a job can run in.
 *
 * `currentEvn` picks the active environment. It is a hard-coded value, so switching
 * environments means editing it and rebuilding.
 *
 * @author wangxuexing
 * @date 2019/12/23
 */
object Profile extends Enumeration {
  type Profile = Value

  /** Production environment. */
  val PROD: Value = Value("prod")

  /** Production-test environment. */
  val PROD_TEST: Value = Value("prod_test")

  /** Development environment. */
  val DEV: Value = Value("dev")

  /** The environment currently in effect. */
  val currentEvn: Profile = PROD
}

 

定义SparkUtil.scala

import com.dmall.scf.Profile
import com.dmall.scf.dto.{Env, MySqlConfig}
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}

import scala.collection.JavaConversions._

/**
 * Spark helpers: per-environment MySQL settings, SparkSession construction, and conversion of
 * DataFrames into case-class lists.
 *
 * @author wangxuexing
 * @date 2019/12/23
 */
object SparkUtils {
  // NOTE(review): connection settings are hard-coded in source (placeholders here); consider
  // loading them from external configuration instead.

  // Development environment
  val DEV_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
  val DEV_USER = "user"
  val DEV_PASSWORD = "password"

  // Production-test environment
  val PROD_TEST_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false"
  val PROD_TEST_USER = "user"
  val PROD_TEST_PASSWORD = "password"

  // Production environment
  val PROD_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
  val PROD_USER = "user"
  val PROD_PASSWORD = "password"

  /** The environment selected by `Profile.currentEvn`. */
  def env = Profile.currentEvn

  /**
   * Builds the MySQL settings and SparkSession for the current environment.
   *
   * @return the environment's MySQL configuration together with its SparkSession
   * @throws Exception if the current environment is not recognized
   */
  def getEnv: Env = env match {
    case Profile.DEV =>
      Env(MySqlConfig(DEV_URL, DEV_USER, DEV_PASSWORD), getDevSparkSession)
    case Profile.PROD =>
      Env(MySqlConfig(PROD_URL, PROD_USER, PROD_PASSWORD), getProdSparkSession)
    case Profile.PROD_TEST =>
      // Production-test uses the same SparkSession settings as production
      Env(MySqlConfig(PROD_TEST_URL, PROD_TEST_USER, PROD_TEST_PASSWORD), getProdSparkSession)
    case _ => throw new Exception("无法获取环境")
  }

  /**
   * Returns (or creates) the production SparkSession with Hive support enabled. The master is
   * not set here, so it comes from the submit configuration.
   */
  def getProdSparkSession: SparkSession =
    SparkSession
      .builder()
      .appName("scf")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()

  /**
   * Returns (or creates) a local SparkSession for development, using all local cores, Hive
   * support and a fixed warehouse directory.
   */
  def getDevSparkSession: SparkSession =
    SparkSession
      .builder()
      .master("local[*]")
      .appName("local-1576939514234")
      // Without this, Spark puts spark-warehouse under the current working directory
      .config("spark.sql.warehouse.dir", "C:\\data\\spark-ware")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()

  /**
   * Collects a DataFrame to the driver as a list of case-class instances.
   *
   * The DataFrame's columns are renamed, by position, to the declared field names of `clazz`.
   * The DataFrame must therefore have its columns in the same order as the case class fields.
   *
   * NOTE(review): `Class.getDeclaredFields` does not guarantee declaration order per the JDK
   * docs, although HotSpot returns fields in declaration order in practice.
   *
   * @param df    the DataFrame to convert
   * @param clazz the case class to convert to
   * @tparam T case class type; an Encoder for it must be in scope
   * @return all rows as instances of T
   */
  def dataFrame2Bean[T: Encoder](df: DataFrame, clazz: Class[T]): List[T] = {
    // Static members (e.g. a serialVersionUID) and compiler-generated synthetic fields are not
    // constructor fields, so they must not become column names.
    val fieldNames = clazz.getDeclaredFields.toList
      .filterNot(f => java.lang.reflect.Modifier.isStatic(f.getModifiers) || f.isSynthetic)
      .map(_.getName)
    // collect() already yields a Scala Array; no implicit Java-collection conversion is needed
    df.toDF(fieldNames: _*).as[T].collect().toList
  }
}

三、定义Spark操作流程

从MySQL或Hive读取数据->逻辑处理->写入MySQL

1. 定义处理流程

SparkAction.scala

import com.dmall.scf.utils.{MySQLUtils, SparkUtils}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
 * Template for a Spark job. A run consists of an optional pre-processing hook, the job's own
 * logic producing a DataFrame, and post-processing that by default writes that DataFrame to
 * MySQL.
 *
 * @tparam T element type associated with the concrete job (not used by the template itself)
 * @author wangxuexing
 * @date 2019/12/23
 */
trait SparkAction[T] {
  /**
   * Runs the job: `preAction`, then `action`, then `postAction` on the resulting DataFrame.
   *
   * @param args  command-line arguments, passed through to `action`
   * @param spark the session to run in
   */
  def execute(args: Array[String], spark: SparkSession): Unit = {
    // 1. pre-processing
    preAction()
    // 2. the job's own logic
    val df = action(spark, args)
    // 3. post-processing
    postAction(df)
  }

  /**
   * Hook run before `action`. The default does nothing.
   */
  def preAction(): Unit = {
    // no pre-processing by default
  }

  /**
   * The job's logic.
   *
   * @param spark the session to run in
   * @param args  command-line arguments
   * @return the result to hand to `postAction`
   */
  def action(spark: SparkSession, args: Array[String]): DataFrame

  /**
   * Hook run after `action`. The default writes the result to the MySQL table named by
   * `saveTable`, using the save mode given there.
   *
   * @param df the result of `action`
   */
  def postAction(df: DataFrame): Unit = {
    val (mode, tableName) = saveTable
    MySQLUtils.writeIntoMySql(df, mode, tableName)
  }

  /**
   * Save mode and target MySQL table used by the default `postAction`.
   *
   * @return (save mode, table name)
   */
  def saveTable: (SaveMode, String)
}

2. 实现流程

KanbanAction.scala

import com.dmall.scf.SparkAction
import com.dmall.scf.dto.KanbanFieldValue
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}

import scala.collection.JavaConverters._

/**
 * Base trait for Kanban statistics jobs. It turns computed `KanbanFieldValue`s into a
 * DataFrame and appends them to the MySQL table `scfc_kanban_field_value`.
 *
 * @author wangxuexing
 * @date 2020/1/10
 */
trait KanbanAction extends SparkAction[KanbanFieldValue] {
  /**
   * Builds a DataFrame from an in-memory list of results, using a fixed six-column schema
   * whose columns are all non-nullable.
   *
   * NOTE(review): `statistics_date` is declared as StringType. Confirm that
   * `KanbanFieldValue.statisticsDate` holds a String, otherwise createDataFrame will reject
   * the rows.
   *
   * @param resultList the values to turn into rows
   * @param spark      session used to create the DataFrame
   * @return one row per element of `resultList`, in the same order
   */
  def getDataFrame(resultList: List[KanbanFieldValue], spark: SparkSession): DataFrame = {
    val schema = StructType(Seq(
      StructField("company_id", LongType, nullable = false),
      StructField("statistics_date", StringType, nullable = false),
      StructField("field_id", LongType, nullable = false),
      StructField("field_type", StringType, nullable = false),
      StructField("field_value", StringType, nullable = false),
      StructField("other_value", StringType, nullable = false)
    ))
    // The values are already on the driver: convert each one to a Row and pass a java.util.List
    val rows = resultList.map { v =>
      Row(v.companyId, v.statisticsDate, v.fieldId, v.fieldType, v.fieldValue, v.otherValue)
    }
    spark.createDataFrame(rows.asJava, schema)
  }

  /**
   * Results are appended to `scfc_kanban_field_value`.
   *
   * @return (save mode, table name)
   */
  override def saveTable: (SaveMode, String) = SaveMode.Append -> "scfc_kanban_field_value"
}

3. 实现具体业务逻辑

import com.dmall.scf.dto.{KanbanFieldValue, RegisteredMoney}
import com.dmall.scf.utils.{DateUtils, MySQLUtils}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Kanban job: distribution of registered capital (注册资本分布) among suppliers.
 *
 * How it works:
 *  - Reads each supplier's registered capital from MySQL (`scfc_supplier_info`, status 3, yn 1).
 *  - Keeps the suppliers returned by the Hive query for the selected period.
 *  - Counts those suppliers in five bands: <=50, (50,100], (100,500], (500,1000], >1000.
 *    The labels call these amounts 万元.
 *
 * @author wangxuexing
 * @date 2020/1/10
 */
object RegMoneyDistributionAction extends KanbanAction {
  val CLASS_NAME = this.getClass.getSimpleName().filter(!_.equals('$'))

  // Upper bounds (inclusive) of the first four bands
  val RANGE_50W = BigDecimal(50)
  val RANGE_100W = BigDecimal(100)
  val RANGE_500W = BigDecimal(500)
  val RANGE_1000W = BigDecimal(1000)

  /**
   * Computes the band counts for the period selected by `args(1)`.
   *
   * `args(1)` chooses the period:
   *  - "1": the current year (the previous year when today is Jan 1), up to yesterday.
   *  - "2": the whole previous year (the year before that when today is Jan 1).
   *
   * @param spark the session to run in
   * @param args  command-line arguments; `args(1)` must be "1" or "2"
   * @return one KanbanFieldValue row per band
   * @throws Exception if `args` has fewer than two elements or `args(1)` is not "1"/"2"
   */
  override def action(spark: SparkSession, args: Array[String]): DataFrame = {
    import spark.implicits._
    if (args.length < 2) {
      throw new Exception("请指定是当前年(值为1)还是去年(值为2):1|2")
    }
    val lastDay = DateUtils.addSomeDays(-1)
    val (startDate, endDate, fieldId) = args(1) match {
      case "1" =>
        val start =
          if (DateUtils.isFirstDayOfYear) DateUtils.getFirstDateOfLastYear
          else DateUtils.getFirstDateOfCurrentYear
        (start, DateUtils.formatNormalDateStr(lastDay), 44)
      case "2" =>
        val start =
          if (DateUtils.isFirstDayOfYear) DateUtils.getLast2YearFirstStr(DateUtils.YYYY_MM_DD)
          else DateUtils.getLastYearFirstStr(DateUtils.YYYY_MM_DD)
        val end =
          if (DateUtils.isFirstDayOfYear) DateUtils.getLast2YearLastStr(DateUtils.YYYY_MM_DD)
          else DateUtils.getLastYearLastStr(DateUtils.YYYY_MM_DD)
        (start, end, 45)
      case _ => throw new Exception("请传入正确的参数:是当前年(值为1)还是去年(值为2):1|2")
    }

    val sql = s"""SELECT
                      id,
                      IFNULL(registered_money, 0) registered_money
                    FROM
                      scfc_supplier_info
                    WHERE
                      `status` = 3
                    AND yn = 1"""
    val allDimension = MySQLUtils.getDFFromMysql(spark, sql)
    val beanList = allDimension.map(x => RegisteredMoney(x.getLong(0), x.getDecimal(1)))
    val hiveSql = s"""
                   SELECT DISTINCT(a.company_id) supplier_ids
                    FROM wm_ods_cx_supplier_card_info a
                    JOIN wm_ods_jrbl_loan_dkzhxx b ON a.card_code = b.gshkahao
                    WHERE a.audit_status = '2'
                      AND b.jiluztai = '0'
                      AND to_date(b.gxinshij)>= '${startDate}'
                      AND to_date(b.gxinshij)<= '${endDate}'"""
    println(hiveSql)
    // A Set gives constant-time membership checks instead of scanning an array for every supplier
    val supplierIds = spark.sql(hiveSql).collect().map(_.getLong(0)).toSet
    val filterList = beanList.filter(x => supplierIds.contains(x.supplierId))

    // One summing counter per band. Unlike a CollectionAccumulator, a LongAccumulator does not
    // ship one list element per supplier back to the driver.
    val range1 = spark.sparkContext.longAccumulator
    val range2 = spark.sparkContext.longAccumulator
    val range3 = spark.sparkContext.longAccumulator
    val range4 = spark.sparkContext.longAccumulator
    val range5 = spark.sparkContext.longAccumulator
    filterList.foreach(x => {
      val money = x.registeredMoney
      // RANGE.compare(money) >= 0 means money <= RANGE; each branch is the next band up
      if (RANGE_50W.compare(money) >= 0) range1.add(1)          // <= 50
      else if (RANGE_100W.compare(money) >= 0) range2.add(1)    // (50, 100]
      else if (RANGE_500W.compare(money) >= 0) range3.add(1)    // (100, 500]
      else if (RANGE_1000W.compare(money) >= 0) range4.add(1)   // (500, 1000]
      else range5.add(1)                                        // > 1000
    })
    val resultList = List(
      "50万元以下" -> range1,
      "50-100万元" -> range2,
      "100-500万元" -> range3,
      "500-1000万元" -> range4,
      "1000万元以上" -> range5
    ).map { case (label, counter) =>
      KanbanFieldValue(1, lastDay, fieldId, label, counter.sum.toString, "")
    }

    getDataFrame(resultList, spark)
  }
}

 具体项目源码请参考:

https://github.com/barrywang88/spark-tool

https://gitee.com/barrywang/spark-tool

 

posted @ 2020-02-18 10:52  BarryW  阅读(7667)  评论(10)  编辑  收藏  举报