[Repost] A Custom HBase Data Source for Spark

Reposted from the original article: Spark自定义HBase数据源

Spark ships with many built-in data sources, but there is none for HBase, so you normally have to fall back to the RDD API. It would be ideal if HBase could be used like this instead:

 frame.write.format("hbase")
      .mode(SaveMode.Append)
      .option(ZK_HOST_HBASE, "bigdata.cn")
      .option(ZK_PORT_HBASE, 2181)
      .option(HBASE_TABLE, "tbl_users_2")
      .option(HBASE_FAMILY, "detail")
      .option(HBASE_ROW_KEY, "id")
      .option(HBASE_SELECT_WHERE_CONDITIONS,"id[ge]50")
      .save()

Following the pattern of Spark SQL's built-in data sources and third-party ones (such as Kudu and Elasticsearch), we can build a custom HBase data source.

I. Approach

1. Write two classes: HBaseRelation and DefaultSource.

HBaseRelation (the data source): implements the read and write logic against HBase. It mixes in two traits: TableScan (read) and InsertableRelation (write).

DefaultSource: the factory class that creates HBaseRelation instances for the framework to call. It implements two traits: RelationProvider (creates the TableScan relation) and CreatableRelationProvider (creates the InsertableRelation relation). To use the short name format("hbase"), it must also implement DataSourceRegister.

The inheritance hierarchy of the two classes is sketched below. The format("hbase") in the code above refers to DefaultSource; the framework asks DefaultSource for an HBaseRelation, and that HBaseRelation then reads from or writes to HBase.
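In skeleton form (method bodies stubbed out with ???; the full implementations appear in the code section below), the hierarchy looks roughly like this:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType

// Skeleton only: which trait contributes which method
class HBaseRelation(sqlCxt: SQLContext, structType: StructType, parameters: Map[String, String])
  extends BaseRelation with TableScan with InsertableRelation with Serializable {
  override def sqlContext: SQLContext = sqlCxt
  override def schema: StructType = structType
  override def buildScan(): RDD[Row] = ???                              // TableScan: the read path
  override def insert(data: DataFrame, overwrite: Boolean): Unit = ???  // InsertableRelation: the write path
}

class DefaultSource
  extends RelationProvider with CreatableRelationProvider with DataSourceRegister {
  override def shortName(): String = "hbase"                            // DataSourceRegister: the alias
  // RelationProvider: called for reads
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = ???
  // CreatableRelationProvider: called for writes
  override def createRelation(sqlContext: SQLContext, mode: SaveMode,
                              parameters: Map[String, String], data: DataFrame): BaseRelation = ???
}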

2. Registration

The "hbase" in format("hbase") is an alias for DefaultSource. For Spark to know about this mapping, two things are needed.

First, DefaultSource extends DataSourceRegister and declares the short name:

override def shortName(): String = "hbase"


Second, register the provider through a service file under the resources directory. Note that the file must live in the META-INF/services directory inside resources.
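Concretely, assuming a standard Maven/sbt layout and the package name used in the code below, the service file is src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister, and its only content is the fully qualified name of the provider class:

com.lcy.models.datasource.DefaultSource

Spark discovers providers through Java's ServiceLoader mechanism, which is what lets format("hbase") resolve to DefaultSource via its shortName.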

II. Code

The Condition class

This class parses filter conditions. For example, .option(HBASE_SELECT_WHERE_CONDITIONS, "id[ge]50") means: query the rows whose id is greater than or equal to 50.

package com.lcy.models.datasource

import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp

import scala.util.matching.Regex

case class Condition(field:String,op:CompareOp,value:String)
object Condition{
  val reg = "(.*)\\[(.*)\\](.*)".r

  def parse(condition: String): Condition = {
    // Split "field[op]value" (e.g. "id[ge]50") into its three parts
    val matches: Iterator[Regex.Match] = reg.findAllMatchIn(condition)
    val matchR: Regex.Match = matches.toList.head

    val field = matchR.group(1)
    val op = matchR.group(2)
    val value = matchR.group(3)

    val compareOp: CompareOp = op match {
        case "ge" => CompareOp.GREATER_OR_EQUAL
        case "gt" => CompareOp.GREATER
        case "eq" => CompareOp.EQUAL
        case "le" => CompareOp.LESS_OR_EQUAL
        case "lt" => CompareOp.LESS
        case "neq" => CompareOp.NOT_EQUAL
        case _ => throw new RuntimeException("not supported operator")
    }

    Condition(field,compareOp,value)

  }
}
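For example, the condition string used in the options above would parse like this (a minimal standalone check; the object name is just illustrative):

import com.lcy.models.datasource.Condition
import org.apache.hadoop.hbase.filter.CompareFilter.CompareOp

object ConditionParseCheck extends App {
  // "id[ge]50" -> field "id", operator GREATER_OR_EQUAL, value "50"
  val c = Condition.parse("id[ge]50")
  assert(c == Condition("id", CompareOp.GREATER_OR_EQUAL, "50"))
  println(c)
}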

DefaultSource:

package com.lcy.models.datasource

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class DefaultSource
  extends RelationProvider
    with CreatableRelationProvider
    with DataSourceRegister
{
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val relation = new HBaseRelation(sqlContext, data.schema, parameters)
    // Note: insert must be called explicitly here; for reads the framework calls buildScan automatically
    relation.insert(data, true)
    relation
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    val fields = parameters.getOrElse("fields.selecting", "").split(",")
    val schema = new StructType(
      fields.map(field => {
        StructField(field, StringType, false)
      })
    )

    // Note: the schema cannot be null, even if it is not used later
    new HBaseRelation(sqlContext, schema, parameters)
  }

  override def shortName(): String = "hbase"
}

HBaseRelation:

package com.lcy.models.datasource

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.{Put, Result, Scan}
import org.apache.hadoop.hbase.filter.{FilterList, SingleColumnValueFilter}
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.{TableInputFormat, TableOutputFormat}
import org.apache.hadoop.hbase.protobuf.ProtobufUtil
import org.apache.hadoop.hbase.util.{Base64, Bytes}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.StructType

import scala.util.matching.Regex

class HBaseRelation(sqlCxt: SQLContext, structType: StructType, parameters: Map[String, String])
  extends BaseRelation with TableScan with InsertableRelation with Serializable{

  val ZK_HOST_HBASE = "hbase.zookeeper.quorum"
  val ZK_PORT_HBASE = "hbase.zookeeper.property.clientPort"
  val HBASE_FAMILY ="family.name"
  val HBASE_ROW_KEY ="row.key.name"
  val HBASE_SELECT_FIELDS ="fields.selecting"
  val HBASE_SELECT_WHERE_CONDITIONS ="where.conditions"
  val HBASE_TABLE ="hbase.table"
  override def sqlContext: SQLContext = sqlCxt

  override def schema: StructType = structType

  // Implicit conversion so plain Strings can be passed where HBase expects Array[Byte]
  implicit def toBytes(str: String): Array[Byte] = {
    Bytes.toBytes(str)
  }

  override def buildScan(): RDD[Row] = {
    // Create the HBase configuration object
    val conf: Configuration = HBaseConfiguration.create()
    conf.set(ZK_HOST_HBASE, fromParameters(ZK_HOST_HBASE))
    conf.set(ZK_PORT_HBASE, fromParameters(ZK_PORT_HBASE))
    conf.set(TableInputFormat.INPUT_TABLE, fromParameters(HBASE_TABLE))

    // Build the Scan object from the connection options
    val scan = new Scan()
    val familyName: String = fromParameters(HBASE_FAMILY)
    scan.addFamily(familyName)
    val fieldsSelecting: String = fromParameters(HBASE_SELECT_FIELDS)
    val fields: Array[String] = fieldsSelecting.split(",")
    fields.foreach(field=>{
      scan.addColumn(familyName,field)
    })

    // Parse the where conditions
    val whereConditions: String = fromParameters(HBASE_SELECT_WHERE_CONDITIONS)
    if (whereConditions != "") {

      val filterList: FilterList = new FilterList
      val conditions: Array[String] = whereConditions.split(",")
      conditions.foreach(condition=>{
        val conditionObj: Condition = Condition.parse(condition)
        filterList.addFilter(new SingleColumnValueFilter(familyName,conditionObj.field,conditionObj.op,conditionObj.value))

        // Important and easy to forget: the filter is applied to the scanned result set, so the filtered column must also be added to the scan
        scan.addColumn(familyName,conditionObj.field)
      })
      scan.setFilter(filterList)
    }

    conf.set(
      TableInputFormat.SCAN,
      Base64.encodeBytes(ProtobufUtil.toScan(scan).toByteArray)
    )
    // Read the data via newAPIHadoopRDD
    sqlContext
      .sparkContext
      .newAPIHadoopRDD(conf,
         classOf[TableInputFormat],
         classOf[ImmutableBytesWritable],
         classOf[Result])

      // Convert the results into RDD[Row]
      .mapPartitions(
        it=>{
          it.map(res=> {
            Row.fromSeq(
              fields.map(field=>{
            // Note: convert the bytes to String here, otherwise show() will fail
                Bytes.toString(res._2.getValue(familyName, field))
              })
            )
          })
        }
      )
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    // Create the HBase configuration object
    val conf: Configuration = HBaseConfiguration.create()
    conf.set(ZK_HOST_HBASE, fromParameters(ZK_HOST_HBASE))
    conf.set(ZK_PORT_HBASE, fromParameters(ZK_PORT_HBASE))
    conf.set(TableOutputFormat.OUTPUT_TABLE,fromParameters(HBASE_TABLE))

    // Convert the DataFrame into an RDD[(ImmutableBytesWritable, Put)]
    val family: String = fromParameters(HBASE_FAMILY)
    val rowKey = fromParameters(HBASE_ROW_KEY)
    val columns: Array[String] = data.columns
    val putRDD: RDD[(ImmutableBytesWritable, Put)] = data.rdd.mapPartitions(
      it => {
        it.map(row => {
          val rowKeyValue = row.getAs[String](rowKey)
          val put = new Put(rowKeyValue)
          put.addColumn(family, rowKey, rowKeyValue)
          // Write every DataFrame column into the given family (addColumn relies on the implicit String -> Array[Byte] conversion)
          columns.foreach(field => {
            put.addColumn(family, field, row.getAs[String](field))
          })
          (new ImmutableBytesWritable(rowKeyValue), put)
        })
      }
    )

    // Write the data out through the new Hadoop API
    putRDD.saveAsNewAPIHadoopFile(
      s"datas/hbase/output-${System.nanoTime()}",
      classOf[ImmutableBytesWritable],
      classOf[Put],
      classOf[TableOutputFormat[ImmutableBytesWritable]],
      conf)
  }

  // Look up an option value, defaulting to the empty string when it is missing
  def fromParameters(key: String): String = {
    parameters.getOrElse(key, "")
  }
}

Test code:

import com.lcy.models.utils.SparkUtils
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

object SparkTest {
  val ZK_HOST_HBASE = "hbase.zookeeper.quorum"
  val ZK_PORT_HBASE = "hbase.zookeeper.property.clientPort"
  val HBASE_FAMILY = "family.name"
  val HBASE_ROW_KEY = "row.key.name"
  val HBASE_SELECT_FIELDS = "fields.selecting"
  val HBASE_SELECT_WHERE_CONDITIONS = "where.conditions"
  val HBASE_TABLE = "hbase.table"

  /**
   * Example configuration used in this test:
   *   inType=hbase
   *   zkHosts=bigdata-cdh01.itcast.cn
   *   zkPort=2181
   *   hbaseTable=tbl_users
   *   family=detail
   *   selectFieldNames=id,gender
   */

  def main(args: Array[String]): Unit = {
    // 1. Create the SparkSession
    val spark: SparkSession = SparkUtils.sparkSession(SparkUtils.sparkConf(this.getClass.getCanonicalName))

    val frame: DataFrame = spark.read
      .format("com.lcy.models.datasource")
      .option(ZK_HOST_HBASE, "bigdata-cdh01.itcast.cn")
      .option(ZK_PORT_HBASE, 2181)
      .option(HBASE_TABLE, "tbl_users")
      .option(HBASE_FAMILY, "detail")
      .option(HBASE_SELECT_FIELDS, "id,gender")
      .option(HBASE_SELECT_WHERE_CONDITIONS,"id[ge]50")
      .load()

    frame.show()

    frame.write.format("hbase")
      .mode(SaveMode.Append)
      .option(ZK_HOST_HBASE, "bigdata-cdh01.itcast.cn")
      .option(ZK_PORT_HBASE, 2181)
      .option(HBASE_TABLE, "tbl_users_2")
      .option(HBASE_FAMILY, "detail")
      .option(HBASE_ROW_KEY, "id")
      .save()

    spark.stop()
  }
}