Spark Development - Spark UDAF (Part 3)

Purpose

 Put the code shared by the aggregation callbacks into a reusable helper class.

Example

import org.roaringbitmap.buffer.ImmutableRoaringBitmap;
import org.roaringbitmap.buffer.MutableRoaringBitmap;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

public class RBitMapBuffer  {

    private MutableRoaringBitmap bitMap;

    /**
     * Default constructor: start with an empty bitmap.
     */
    public RBitMapBuffer(){
        bitMap = new MutableRoaringBitmap();
    }

    /**
     * Constructor that rebuilds the bitmap from its serialized byte form.
     */
    public RBitMapBuffer(byte[] buffer){
        bitMap =  new ImmutableRoaringBitmap(ByteBuffer.wrap(buffer)).toMutableRoaringBitmap();
    }

    /**
     * Add a single id to the bitmap.
     */
    public void addItem(int id) {
        bitMap.add(id);
    }

    /**
     * Merge (OR) another serialized bitmap into this one.
     */
    public void merge( byte[] buffer) throws IOException {
        if (buffer == null) {
            return;
        }
        // ByteBuffer is just a wrapper around a byte array, so the static wrap(byte[] data) can wrap it directly
        // bitMap itself is assumed non-null here; only the incoming buffer needs to be deserialized
        ImmutableRoaringBitmap other = new ImmutableRoaringBitmap(ByteBuffer.wrap(buffer));
        bitMap.or( other.toMutableRoaringBitmap());
    }
    
    /**
     * Serialize the current bitmap.
     * @return the bitmap as a byte array, or null if the bitmap is null
     * @throws IOException if serialization fails
     */
    public byte[] getPartial() throws IOException {
        if (bitMap == null) {
            return null;
        }
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream ndos = new DataOutputStream(bos);
        bitMap.serialize(ndos);
        ndos.close();
        return bos.toByteArray();
    }

    public int getCardinalityCount() {
        return bitMap.getCardinality();
    }
    
    public void reset() {
        bitMap.clear();
    }
      
}
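
The class above is plain Java, so its serialize/merge round trip can be checked outside Spark. Below is a minimal sketch (not from the original post; the class name RBitMapBufferDemo is made up) that fills two buffers, merges one into the other through the serialized byte[] form, and prints the deduplicated count:

import java.io.IOException;

public class RBitMapBufferDemo {
    public static void main(String[] args) throws IOException {
        // One buffer per "partition"
        RBitMapBuffer left = new RBitMapBuffer();
        left.addItem(1);
        left.addItem(2);

        RBitMapBuffer right = new RBitMapBuffer();
        right.addItem(2);
        right.addItem(3);

        // Exchange the right side as bytes, exactly as the UDAF buffer does
        left.merge(right.getPartial());

        // The bitmap OR removes duplicates: prints 3
        System.out.println(left.getCardinalityCount());
    }
}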

UDAF

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.io.*;
import java.util.ArrayList;
import java.util.List;

public class RBitmapBufferUDAF  extends UserDefinedAggregateFunction {
    /**
     * Input schema of the aggregate function.
     */
    @Override
    public StructType inputSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field", DataTypes.IntegerType, true));
        return DataTypes.createStructType(structFields);
    }

    /**
     * Schema of the aggregation buffer, i.e. the type of the intermediate state.
     */
    @Override
    public StructType bufferSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field", DataTypes.BinaryType, true));
        return DataTypes.createStructType(structFields);
    }

    /**
     * Return type of the aggregate function.
     */
    @Override
    public DataType dataType() {
        return DataTypes.LongType;
    }

    /**
     * Whether the function is deterministic, i.e. whether the same input always
     * produces the same output.
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * Initialize the aggregation buffer.
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // Start with an empty (null) serialized bitmap
        buffer.update(0, null);
    }

    /**
     * Process one new input row within a partition, similar to combineByKey:
     * the input row updates the buffer, and buffer.get(0) holds the serialized
     * bitmap produced by previous calls.
     */

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // Per-partition aggregation: fold one input value into this partition's bitmap
        Object in = input.get(0);
        Object out = buffer.get(0);
        // 1. A null input leaves the buffer unchanged
        if (in == null) {
            return;
        }
        int inInt = Integer.valueOf(in.toString());
        // 2. An empty buffer becomes a fresh bitmap containing only the input
        if (out == null) {
            RBitMapBuffer bitMapAgg = new RBitMapBuffer();
            bitMapAgg.addItem(inInt);
            try {
                buffer.update(0, bitMapAgg.getPartial());
            } catch (IOException e) {
                e.printStackTrace();
            }
            return;
        }
        // 3. Otherwise add the input to the existing bitmap; the bitmap deduplicates
        byte[] outBytes = (byte[]) out;
        byte[] result = outBytes;
        RBitMapBuffer bitMapAgg = new RBitMapBuffer(outBytes);
        try {
            bitMapAgg.addItem(inInt);
            result = bitMapAgg.getPartial();
        } catch (IOException e) {
            e.printStackTrace();
        }
        buffer.update(0, result);
    }


    /**
     * Merge two aggregation buffers (buffer2 into buffer1). This is called when
     * combining partial results from different partitions, similar to reduceByKey.
     * The method has no return value: the merged result must be written back into buffer1.
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        // Cross-partition merge of two partial bitmaps
        Object out = buffer1.get(0);
        byte[] inBytes = (byte[]) buffer2.get(0);
        // If buffer1 is still empty, simply take buffer2's bitmap
        if (out == null) {
            buffer1.update(0, inBytes);
            return;
        }
        // Otherwise OR the two bitmaps together
        byte[] outBitBytes = (byte[]) out;
        byte[] resultBit = outBitBytes;
        try {
            RBitMapBuffer bitMapAgg = new RBitMapBuffer(outBitBytes);
            bitMapAgg.merge(inBytes);
            resultBit = bitMapAgg.getPartial();
        } catch (IOException e) {
            e.printStackTrace();
        }
        buffer1.update(0, resultBit);
    }

    /**
     * Compute the final result from the aggregation buffer.
     */
    @Override
    public Object evaluate(Row buffer) {
        // Deserialize the final bitmap and return its cardinality (the distinct count)
        long r = 0L;
        Object val = buffer.get(0);
        if (val != null) {
            RBitMapBuffer rr = new RBitMapBuffer((byte[]) val);
            r = rr.getCardinalityCount();
        }
        // To return the serialized bitmap itself, change dataType() to BinaryType
        // and return the bytes from getPartial() instead of the cardinality
        return r;
    }
}
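
The post does not show how the UDAF is wired into Spark SQL. Below is a minimal registration sketch (not from the original post), assuming a Spark 2.x SparkSession and a hypothetical visits table with an int-typed user_id column; all table, column, and function names are made up:

import org.apache.spark.sql.SparkSession;

public class RBitmapBufferUDAFExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("bitmap-udaf-example").getOrCreate();

        // Register the UDAF under a SQL-callable name
        spark.udf().register("bitmap_distinct", new RBitmapBufferUDAF());

        // user_id must fit in an int, because the bitmap stores 32-bit values
        spark.sql("SELECT city, bitmap_distinct(user_id) AS uv FROM visits GROUP BY city").show();

        spark.stop();
    }
}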

References

 Hive UDF for reading/writing RoaringBitmaps stored in HBase: https://blog.csdn.net/qq_34748569/article/details/105252559
 https://github.com/sunyaf/bitmapudf
 Implementing a bitmap function as a Spark SQL UDAF: https://blog.csdn.net/xiongbingcool/article/details/81282118