Spark操作MySQL、Hive并写入MySQL数据库
最近的一个项目需要对近70亿条数据进行统计分析。如此大量的数据如果存入MySQL,读取会非常困难,即使借助搜索引擎也非常慢。经过调研,我们决定借助公司的大数据平台,结合Spark技术来完成这么大数据量的统计分析。
为了方便后续开发,我们编写了几个工具类,屏蔽对MySQL及Hive的操作代码,使开发人员只需关心业务代码的编写。
工具类如下:
一、Spark操作MySQL
1. 根据sql语句获取Spark DataFrame:
/**
 * Loads the result of a SQL query from MySQL as a Spark DataFrame.
 *
 * The query is handed to Spark's JDBC source through the `query` option, using the
 * url/user/password held in `mySqlConfig`. The url, user and query are printed first.
 *
 * @param spark active SparkSession
 * @param sql   query executed on the MySQL side
 * @return DataFrame holding the query result
 */
def getDFFromMysql(spark: SparkSession, sql: String): DataFrame = {
  val cfg = mySqlConfig
  println(s"url:${cfg.url} user:${cfg.user} sql: ${sql}")
  val reader = spark.read
    .format("jdbc")
    .option("driver", "com.mysql.jdbc.Driver")
    .option("url", cfg.url)
    .option("user", cfg.user)
    .option("password", cfg.password)
  reader.option("query", sql).load()
}
2. 将Spark DataFrame 写入MySQL数据库表
/**
 * Writes a DataFrame into a MySQL table using the requested save mode.
 *
 * Only two modes are supported:
 * - `SaveMode.Append` delegates to appendDataIntoMysql.
 * - `SaveMode.Overwrite` delegates to overwriteMysqlData, which truncates and then appends.
 *
 * @param df        rows to write
 * @param mode      SaveMode.Append or SaveMode.Overwrite
 * @param tableName target MySQL table
 * @throws Exception for any other SaveMode
 */
def writeIntoMySql(df: DataFrame, mode: SaveMode, tableName: String): Unit =
  if (mode == SaveMode.Append) appendDataIntoMysql(df, tableName)
  else if (mode == SaveMode.Overwrite) overwriteMysqlData(df, tableName)
  else throw new Exception("目前只支持Append及Overwrite!")
/**
 * Appends every row of a DataFrame to a MySQL table through Spark's JDBC writer.
 *
 * @param df             rows to insert
 * @param mysqlTableName target table, written as database_name.table_name
 */
def appendDataIntoMysql(df: DataFrame, mysqlTableName: String): Unit = {
  val writer = df.write.mode(SaveMode.Append)
  writer.jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
}
/**
 * Replaces the contents of a MySQL table with the rows of a DataFrame.
 *
 * Implemented as truncate-then-append. NOTE(review): the two steps are not atomic. If the
 * append fails, the table is left empty. truncateMysqlTable logs its own failures and reports
 * them only through its Boolean result, which is not checked here.
 *
 * @param df             rows to write
 * @param mysqlTableName target table, written as database_name.table_name
 */
def overwriteMysqlData(df: DataFrame, mysqlTableName: String): Unit = {
  // Step 1: empty the table.
  truncateMysqlTable(mysqlTableName)
  // Step 2: append the new rows.
  appendDataIntoMysql(df, mysqlTableName)
}
/**
 * Removes every row from a MySQL table with `TRUNCATE TABLE`.
 *
 * Failures are printed and reported through the return value rather than thrown.
 *
 * @param mysqlTableName table to truncate. It is interpolated verbatim into the SQL, so it
 *                       must be a trusted identifier.
 * @return true if the statement ran, false if any step raised an exception
 */
def truncateMysqlTable(mysqlTableName: String): Boolean = {
  val conn = MySQLPoolManager.getMysqlManager.getConnection // borrow a connection from the pool
  try {
    val statement = conn.createStatement()
    try {
      statement.execute(s"truncate table $mysqlTableName")
      // Statement.execute returns false for statements that yield no ResultSet (TRUNCATE
      // included), so its result cannot be used as the success flag.
      true
    } finally {
      statement.close()
    }
  } catch {
    case e: Exception =>
      println(s"mysql truncateMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
      false
  } finally {
    // Always return the connection, even if createStatement itself failed.
    conn.close()
  }
}
3. 根据条件删除MySQL表数据
/**
 * Deletes the rows of a MySQL table that match a WHERE condition.
 *
 * Failures are printed and reported through the return value rather than thrown.
 * NOTE(review): both `mysqlTableName` and `condition` are interpolated verbatim into the SQL
 * text. Never pass values derived from untrusted input.
 *
 * @param mysqlTableName table to delete from
 * @param condition      raw SQL boolean expression placed after `where`
 * @return true if the statement ran, false if any step raised an exception
 */
def deleteMysqlTableData(mysqlTableName: String, condition: String): Boolean = {
  val conn = MySQLPoolManager.getMysqlManager.getConnection // borrow a connection from the pool
  try {
    val statement = conn.createStatement()
    try {
      statement.execute(s"delete from $mysqlTableName where $condition")
      // Statement.execute returns false for DELETE (an update count, not a ResultSet),
      // so success has to be reported explicitly.
      true
    } finally {
      statement.close()
    }
  } catch {
    case e: Exception =>
      println(s"mysql deleteMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
      false
  } finally {
    // Always return the connection, even if createStatement itself failed.
    conn.close()
  }
}
4. 保存DataFrame 到 MySQL中,如果表不存在的话,会自动创建
/**
 * Saves a DataFrame into a MySQL table, creating the table from the DataFrame schema
 * when it does not exist yet.
 *
 * The work runs as three steps, in order:
 * 1. create the table if it is missing;
 * 2. verify that table and DataFrame columns agree in count, name and order;
 * 3. insert the rows through the connection pool.
 *
 * @param tableName       target MySQL table
 * @param resultDateFrame data to write
 */
def saveDFtoDBCreateTableIfNotExist(tableName: String, resultDateFrame: DataFrame): Unit = {
  val df = resultDateFrame
  createTableIfNotExist(tableName, df)  // 1. create from the schema if needed
  verifyFieldConsistency(tableName, df) // 2. fail fast on a column mismatch
  saveDFtoDBUsePool(tableName, df)      // 3. write the rows
}
/**
 * Creates a MySQL table from a DataFrame schema when no table with that name is found.
 *
 * Column order follows the DataFrame. A field named `id` (case-insensitive) becomes an
 * `int AUTO_INCREMENT PRIMARY KEY`. Every other field is mapped from its Spark DataType;
 * strings become varchar(50), so longer values may not fit.
 *
 * NOTE(review): existence is probed with DatabaseMetaData.getColumns using `tableName` as a
 * LIKE pattern, so `_` and `%` in the name act as wildcards. Confirm that is acceptable.
 *
 * @param tableName table to look up and, if absent, create
 * @param df        DataFrame whose schema drives the DDL
 * @return true if CREATE TABLE was executed, false if the table already had columns
 * @throws RuntimeException for a DataType with no mapping below
 * @throws IllegalArgumentException if the table must be created but the schema is empty
 */
def createTableIfNotExist(tableName: String, df: DataFrame): AnyVal = {
  val con = MySQLPoolManager.getMysqlManager.getConnection
  try {
    val colResultSet = con.getMetaData.getColumns(null, "%", tableName, "%")
    // The table counts as existing when the metadata query returns at least one column.
    val tableExists = try colResultSet.next() finally colResultSet.close()
    if (tableExists) {
      false
    } else {
      require(df.schema.fields.nonEmpty, s"cannot create table $tableName from a DataFrame with no columns")
      val columnDefs = df.schema.fields.map { x =>
        if (x.name.equalsIgnoreCase("id")) {
          // A field named id becomes an auto-increment integer primary key.
          s"`${x.name}` int(255) NOT NULL AUTO_INCREMENT PRIMARY KEY"
        } else {
          x.dataType match {
            case _: ByteType => s"`${x.name}` int(100) DEFAULT NULL"
            case _: ShortType => s"`${x.name}` int(100) DEFAULT NULL"
            case _: IntegerType => s"`${x.name}` int(100) DEFAULT NULL"
            case _: LongType => s"`${x.name}` bigint(100) DEFAULT NULL"
            case _: BooleanType => s"`${x.name}` tinyint DEFAULT NULL"
            case _: FloatType => s"`${x.name}` float(50) DEFAULT NULL"
            // MySQL DOUBLE accepts only an (M,D) pair; `double(50)` is rejected as a syntax error.
            case _: DoubleType => s"`${x.name}` double DEFAULT NULL"
            case _: StringType => s"`${x.name}` varchar(50) DEFAULT NULL"
            case _: TimestampType => s"`${x.name}` timestamp DEFAULT current_timestamp"
            case _: DateType => s"`${x.name}` date DEFAULT NULL"
            case _ => throw new RuntimeException(s"nonsupport ${x.dataType} !!!")
          }
        }
      }
      // mkString places the separators, so no trailing comma has to be stripped.
      val sql_createTable =
        columnDefs.mkString(s"CREATE TABLE `$tableName` (", ",", ") ENGINE=InnoDB DEFAULT CHARSET=utf8")
      println(sql_createTable)
      val statement = con.createStatement()
      try statement.execute(sql_createTable) finally statement.close()
      true
    }
  } finally {
    con.close()
  }
}
/**
 * Verifies that a MySQL table and a DataFrame have the same columns: same count, same
 * names, same order.
 *
 * NOTE(review): columns are looked up with DatabaseMetaData.getColumns using `tableName` as
 * a LIKE pattern across all schemas. A same-named table in another schema could therefore
 * contribute columns. Confirm against the target database.
 *
 * @param tableName MySQL table to compare
 * @param df        DataFrame to compare
 * @throws Exception on the first count or name mismatch
 */
def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
  val con = MySQLPoolManager.getMysqlManager.getConnection
  // Column names are read with forward-only next(). Scroll methods such as last()/absolute()
  // throw on TYPE_FORWARD_ONLY result sets, and the metadata result-set type is driver dependent.
  val tableFieldNames: Vector[String] =
    try {
      val colResultSet = con.getMetaData.getColumns(null, "%", tableName, "%")
      try {
        val names = Vector.newBuilder[String]
        while (colResultSet.next()) names += colResultSet.getString("COLUMN_NAME")
        names.result()
      } finally colResultSet.close()
    } finally con.close()

  val dfFieldNames = df.columns.toVector
  val tableFiledNum = tableFieldNames.length
  val dfFiledNum = dfFieldNames.length
  if (tableFiledNum != dfFiledNum) {
    throw new Exception(s"数据表和DataFrame字段个数不一致!!table--$tableFiledNum but dataFrame--$dfFiledNum")
  }
  tableFieldNames.zip(dfFieldNames).foreach { case (tableFileName, dfFiledName) =>
    if (!tableFileName.equals(dfFiledName)) {
      throw new Exception(s"数据表和DataFrame字段名不一致!!table--'$tableFileName' but dataFrame--'$dfFiledName'")
    }
  }
}
/**
 * Inserts a DataFrame into an existing MySQL table with JDBC batches on pooled connections.
 *
 * Each partition borrows one connection from MySQLPoolManager and binds every row to the
 * INSERT built by getInsertSql. Rows are flushed every `batchSize` rows, and the partition is
 * committed once at the end.
 *
 * Values are bound as follows:
 * - Non-null values use the JDBC setter matching their Spark DataType.
 * - Nulls use the column's DATA_TYPE from the table metadata.
 *
 * The DataFrame columns are therefore assumed to be in the same order as the table columns,
 * which verifyFieldConsistency checks.
 *
 * An exception inside a partition is printed and that partition's uncommitted rows are
 * rolled back. The exception is not rethrown, so the Spark job itself does not fail.
 *
 * @param tableName       target MySQL table
 * @param resultDateFrame rows to insert
 * @param batchSize       rows buffered per executeBatch call. This bounds executor memory for
 *                        large partitions. Must be positive.
 * @throws IllegalArgumentException if batchSize is not positive
 */
def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame, batchSize: Int = 1000): Unit = {
  require(batchSize > 0, s"batchSize must be positive, got $batchSize")
  val colNumbers = resultDateFrame.columns.length
  val sql = getInsertSql(tableName, colNumbers)
  val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
  resultDateFrame.foreachPartition(partitionRecords => {
    // This closure runs on executors. It deliberately references only local values and the
    // pool/log helpers, never members of the enclosing object.
    val conn = MySQLPoolManager.getMysqlManager.getConnection // borrow a connection from the pool
    try {
      // JDBC type code of each table column, read once per partition for setNull.
      val nullSqlTypes: Vector[Int] = {
        val columns = conn.getMetaData.getColumns(null, "%", tableName, "%")
        try {
          val types = Vector.newBuilder[Int]
          while (columns.next()) types += columns.getInt("DATA_TYPE")
          types.result()
        } finally columns.close()
      }
      val preparedStatement = conn.prepareStatement(sql)
      try {
        conn.setAutoCommit(false)
        var pending = 0 // rows added to the current, not yet executed, batch
        partitionRecords.foreach(record => {
          // JDBC parameters are 1-based; Row fields are 0-based.
          for (i <- 1 to colNumbers) {
            val idx = i - 1
            if (record.isNullAt(idx)) {
              preparedStatement.setNull(i, nullSqlTypes.lift(idx).getOrElse(java.sql.Types.NULL))
            } else {
              columnDataTypes(idx) match {
                // Byte/Short values are boxed as java.lang.Byte/Short; getAs[Int] would throw
                // ClassCastException, so read them with their own getters and widen.
                case _: ByteType => preparedStatement.setInt(i, record.getByte(idx).toInt)
                case _: ShortType => preparedStatement.setInt(i, record.getShort(idx).toInt)
                case _: IntegerType => preparedStatement.setInt(i, record.getInt(idx))
                case _: LongType => preparedStatement.setLong(i, record.getLong(idx))
                case _: BooleanType => preparedStatement.setBoolean(i, record.getBoolean(idx))
                case _: FloatType => preparedStatement.setFloat(i, record.getFloat(idx))
                case _: DoubleType => preparedStatement.setDouble(i, record.getDouble(idx))
                case _: StringType => preparedStatement.setString(i, record.getString(idx))
                case _: TimestampType => preparedStatement.setTimestamp(i, record.getTimestamp(idx))
                case _: DateType => preparedStatement.setDate(i, record.getDate(idx))
                case other => throw new RuntimeException(s"nonsupport ${other} !!!")
              }
            }
          }
          preparedStatement.addBatch()
          pending += 1
          if (pending >= batchSize) {
            preparedStatement.executeBatch()
            pending = 0
          }
        })
        if (pending > 0) preparedStatement.executeBatch()
        conn.commit()
      } catch {
        case e: Exception =>
          println(s"@@ saveDFtoDBUsePool error: ${ExceptionUtil.getExceptionStack(e)}")
          // Discard this partition's uncommitted rows. The primary error was already logged,
          // so a secondary rollback failure is ignored.
          try conn.rollback() catch { case _: java.sql.SQLException => () }
      } finally {
        preparedStatement.close()
      }
    } finally {
      conn.close()
    }
  })
}
二、操作Spark
1. 切换Spark环境
定义环境Profile.scala
/**
 * @description Deployment profiles recognised by the scf jobs.
 * scf
 * @author wangxuexing
 * @date 2019/12/23
 */
object Profile extends Enumeration {
  type Profile = Value

  /** Production environment. */
  val PROD: Value = Value(name = "prod")

  /** Production-test environment. */
  val PROD_TEST: Value = Value(name = "prod_test")

  /** Development environment. */
  val DEV: Value = Value(name = "dev")

  /** Profile currently in effect; edit this constant to switch environments. */
  val currentEvn: Profile = PROD
}
定义SparkUtils.scala
import com.dmall.scf.Profile
import com.dmall.scf.dto.{Env, MySqlConfig}
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}

/**
 * @description Spark utilities: per-profile MySQL settings, SparkSession factories and
 * DataFrame helpers.
 * scf
 * @author wangxuexing
 * @date 2019/12/23
 */
object SparkUtils {
  // NOTE(review): credentials below are hard-coded placeholders. Consider loading them from
  // external configuration instead of committing them to source.

  // Development environment
  val DEV_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
  val DEV_USER = "user"
  val DEV_PASSWORD = "password"
  // Production-test environment
  val PROD_TEST_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false"
  val PROD_TEST_USER = "user"
  val PROD_TEST_PASSWORD = "password"
  // Production environment
  val PROD_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
  val PROD_USER = "user"
  val PROD_PASSWORD = "password"

  /** Profile currently in effect (Profile.currentEvn). */
  def env = Profile.currentEvn

  /**
   * Resolves the MySQL settings and SparkSession for the current profile.
   * PROD_TEST uses the production SparkSession with the production-test MySQL settings.
   *
   * @return Env for `env`
   * @throws Exception if the profile has no mapping below
   */
  def getEnv: Env = {
    env match {
      case Profile.DEV => Env(MySqlConfig(DEV_URL, DEV_USER, DEV_PASSWORD), SparkUtils.getDevSparkSession)
      case Profile.PROD => Env(MySqlConfig(PROD_URL, PROD_USER, PROD_PASSWORD), SparkUtils.getProdSparkSession)
      case Profile.PROD_TEST => Env(MySqlConfig(PROD_TEST_URL, PROD_TEST_USER, PROD_TEST_PASSWORD), SparkUtils.getProdSparkSession)
      case _ => throw new Exception("无法获取环境")
    }
  }

  /**
   * Returns the production SparkSession (cluster master supplied externally) with Hive support.
   *
   * @return existing or newly created SparkSession
   */
  def getProdSparkSession: SparkSession = {
    SparkSession
      .builder()
      .appName("scf")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()
  }

  /**
   * Returns a local development SparkSession with Hive support.
   *
   * @return existing or newly created SparkSession
   */
  def getDevSparkSession: SparkSession = {
    SparkSession
      .builder()
      .master("local[*]")
      .appName("local-1576939514234")
      // If unset, the warehouse defaults to C:\data\projects\parquet2dbs\spark-warehouse.
      .config("spark.sql.warehouse.dir", "C:\\data\\spark-ware")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()
  }

  /**
   * Converts a DataFrame into a list of case-class instances.
   *
   * Columns are renamed positionally to the declared fields of `clazz`, so the DataFrame
   * column order must match the case-class field order.
   *
   * @param df    DataFrame to convert
   * @param clazz runtime class of the case class
   * @tparam T case class with an Encoder in scope
   * @return all rows collected to the driver
   */
  def dataFrame2Bean[T: Encoder](df: DataFrame, clazz: Class[T]): List[T] = {
    val fieldNames = clazz.getDeclaredFields.map(_.getName).toList
    // Dataset.collect already yields a Scala Array, so no (deprecated) JavaConversions
    // implicit is needed to turn the result into a List.
    df.toDF(fieldNames: _*).as[T].collect().toList
  }
}
三、定义Spark操作流程
从MySQL或Hive读取数据->逻辑处理->写入MySQL
1. 定义处理流程
SparkAction.scala
import com.dmall.scf.utils.{MySQLUtils, SparkUtils}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
 * @description Template for a Spark job: pre-processing, the main action, then persisting the
 * resulting DataFrame to MySQL.
 * @author wangxuexing
 * @date 2019/12/23
 * @tparam T result bean type of the concrete job (not referenced by this trait itself)
 */
trait SparkAction[T] {

  /**
   * Runs the job: preAction(), then action(spark, args), then postAction on its result.
   *
   * @param args  command-line arguments passed through to action
   * @param spark active SparkSession
   */
  def execute(args: Array[String], spark: SparkSession): Unit = {
    // 1. Pre-processing (explicit () – auto-application of empty-paren methods is deprecated)
    preAction()
    // 2. Main processing
    val df = action(spark, args)
    // 3. Post-processing
    postAction(df)
  }

  /**
   * Pre-processing hook; does nothing by default.
   */
  def preAction(): Unit = {
    // no pre-processing
  }

  /**
   * Main processing step implemented by each job.
   *
   * @param spark active SparkSession
   * @param args  command-line arguments
   * @return DataFrame to be persisted by postAction
   */
  def action(spark: SparkSession, args: Array[String]): DataFrame

  /**
   * Post-processing hook; by default writes `df` to MySQL with the mode and table from saveTable.
   *
   * @param df result produced by action
   */
  def postAction(df: DataFrame): Unit = {
    val (mode, tableName) = saveTable
    MySQLUtils.writeIntoMySql(df, mode, tableName)
  }

  /**
   * Save mode and target MySQL table used by postAction.
   *
   * @return (SaveMode, table name)
   */
  def saveTable: (SaveMode, String)
}
2. 实现流程
KanbanAction.scala
import com.dmall.scf.SparkAction
import com.dmall.scf.dto.KanbanFieldValue
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import scala.collection.JavaConverters._

/**
 * @description Base for kanban statistics jobs. Turns KanbanFieldValue results into a
 * DataFrame and appends them to the scfc_kanban_field_value table.
 * scf-spark
 * @author wangxuexing
 * @date 2020/1/10
 */
trait KanbanAction extends SparkAction[KanbanFieldValue] {

  /**
   * Builds a DataFrame from kanban results held on the driver.
   *
   * @param resultList results to convert, one row each
   * @param spark      session used to create the DataFrame
   * @return DataFrame with columns company_id, statistics_date, field_id, field_type,
   *         field_value and other_value, all declared non-nullable
   */
  def getDataFrame(resultList: List[KanbanFieldValue], spark: SparkSession): DataFrame = {
    val schema = StructType(Seq(
      StructField("company_id", LongType, nullable = false),
      StructField("statistics_date", StringType, nullable = false),
      StructField("field_id", LongType, nullable = false),
      StructField("field_type", StringType, nullable = false),
      StructField("field_value", StringType, nullable = false),
      StructField("other_value", StringType, nullable = false)
    ))
    // A local java.util.List[Row] (not an RDD) is handed straight to createDataFrame.
    val rows: java.util.List[Row] = resultList.map { v =>
      Row(v.companyId, v.statisticsDate, v.fieldId, v.fieldType, v.fieldValue, v.otherValue)
    }.asJava
    spark.createDataFrame(rows, schema)
  }

  /**
   * Results are appended to scfc_kanban_field_value.
   *
   * @return (SaveMode.Append, target table)
   */
  override def saveTable: (SaveMode, String) = (SaveMode.Append, "scfc_kanban_field_value")
}
3. 实现具体业务逻辑
import com.dmall.scf.dto.{KanbanFieldValue, RegisteredMoney}
import com.dmall.scf.utils.{DateUtils, MySQLUtils}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * @description scf-spark: registered-capital distribution of suppliers.
 *
 * The job selects suppliers from scfc_supplier_info with `status` = 3 and yn = 1 that also
 * appear in the Hive loan query for the chosen period. Matching suppliers are bucketed by
 * registered_money against the RANGE_* thresholds, and one KanbanFieldValue is emitted per
 * bucket.
 *
 * @author wangxuexing
 * @date 2020/1/10
 */
object RegMoneyDistributionAction extends KanbanAction {
  val CLASS_NAME = this.getClass.getSimpleName().filter(!_.equals('$'))
  val RANGE_50W = BigDecimal(50)
  val RANGE_100W = BigDecimal(100)
  val RANGE_500W = BigDecimal(500)
  val RANGE_1000W = BigDecimal(1000)

  /** Bucket labels; position i labels the bucket index i returned by rangeIndex. */
  private val RANGE_LABELS = List("50万元以下", "50-100万元", "100-500万元", "500-1000万元", "1000万元以上")

  /**
   * Maps a registered_money value to its bucket index, 0 to 4. Upper bounds are inclusive:
   * <= 50, (50, 100], (100, 500], (500, 1000], > 1000.
   */
  private def rangeIndex(money: BigDecimal): Int =
    if (RANGE_50W.compare(money) >= 0) 0
    else if (RANGE_100W.compare(money) >= 0) 1
    else if (RANGE_500W.compare(money) >= 0) 2
    else if (RANGE_1000W.compare(money) >= 0) 3
    else 4

  /**
   * Computes the distribution for the current year or the previous year.
   *
   * @param spark active SparkSession
   * @param args  args(1) selects the period:
   *              - "1": current year up to yesterday (field 44);
   *              - "2": the previous year (field 45).
   *              On the first day of a year, both options shift back one year.
   * @return one row per bucket, built by getDataFrame
   * @throws Exception if args has fewer than 2 elements or args(1) is not "1" or "2"
   */
  override def action(spark: SparkSession, args: Array[String]): DataFrame = {
    import spark.implicits._
    if (args.length < 2) {
      throw new Exception("请指定是当前年(值为1)还是去年(值为2):1|2")
    }
    val lastDay = DateUtils.addSomeDays(-1)
    val (startDate, endDate, fieldId) = args(1) match {
      case "1" =>
        val start =
          if (DateUtils.isFirstDayOfYear) DateUtils.getFirstDateOfLastYear
          else DateUtils.getFirstDateOfCurrentYear
        (start, DateUtils.formatNormalDateStr(lastDay), 44)
      case "2" =>
        val firstDayOfYear = DateUtils.isFirstDayOfYear
        val start =
          if (firstDayOfYear) DateUtils.getLast2YearFirstStr(DateUtils.YYYY_MM_DD)
          else DateUtils.getLastYearFirstStr(DateUtils.YYYY_MM_DD)
        val end =
          if (firstDayOfYear) DateUtils.getLast2YearLastStr(DateUtils.YYYY_MM_DD)
          else DateUtils.getLastYearLastStr(DateUtils.YYYY_MM_DD)
        (start, end, 45)
      case _ => throw new Exception("请传入正确的参数:是当前年(值为1)还是去年(值为2):1|2")
    }

    val sql = s"""SELECT id, IFNULL(registered_money, 0) registered_money FROM scfc_supplier_info WHERE `status` = 3 AND yn = 1"""
    val allDimension = MySQLUtils.getDFFromMysql(spark, sql)
    val beanList = allDimension.map(x => RegisteredMoney(x.getLong(0), x.getDecimal(1)))

    val hiveSql =
      s"""SELECT DISTINCT(a.company_id) supplier_ids
         |FROM wm_ods_cx_supplier_card_info a
         |JOIN wm_ods_jrbl_loan_dkzhxx b ON a.card_code = b.gshkahao
         |WHERE a.audit_status = '2' AND b.jiluztai = '0'
         |AND to_date(b.gxinshij)>= '${startDate}' AND to_date(b.gxinshij)<= '${endDate}'""".stripMargin
    println(hiveSql)
    // A Set gives constant-time membership checks; Array.contains scanned linearly per row.
    val supplierIds = spark.sql(hiveSql).collect().map(_.getLong(0)).toSet
    val filterList = beanList.filter(x => supplierIds.contains(x.supplierId))

    // One counter per bucket, in RANGE_LABELS order. Each supplier increments exactly the
    // counter of its own bucket.
    val counters = Vector.fill(RANGE_LABELS.size)(spark.sparkContext.longAccumulator)
    filterList.foreach(x => {
      counters(rangeIndex(x.registeredMoney)).add(1L)
    })

    val resultList = RANGE_LABELS.zip(counters).map { case (label, counter) =>
      KanbanFieldValue(1, lastDay, fieldId, label, counter.value.toString, "")
    }
    getDataFrame(resultList, spark)
  }
}
具体项目源码请参考:
https://github.com/barrywang88/spark-tool
https://gitee.com/barrywang/spark-tool
每天一点成长,欢迎指正!