Fork me on GitHub

【Spark篇】---SparkSql之UDF函数和UDAF函数

一、前述

SparkSql中自定义函数包括UDF和UDAF

UDF:一进一出  UDAF:多进一出 (联想Sum函数)

二、UDF函数

  UDF:用户自定义函数,user defined function

    * 根据UDF函数参数的个数来决定是实现哪一个UDF  UDF1,UDF2。。。。UDF1xxx
    * UDF1 传一个参数  UDF2传两个参数。。。。。

            sqlContext.udf().register("StrLen", new UDF1<String,Integer>() {

                private static final long serialVersionUID = 1L;

                @Override
                public Integer call(String t1) throws Exception {
                    return t1.length();
                }
            }, DataTypes.IntegerType);
            sqlContext.sql("select name ,StrLen(name) as length from user").show();
 sqlContext.udf().register("StrLen",new UDF2<String, Integer, Integer>() {

            /**
             *
             */
            private static final long serialVersionUID = 1L;

            @Override
            public Integer call(String t1, Integer t2) throws Exception {
                return t1.length()+t2;
            }
        } ,DataTypes.IntegerType );
        sqlContext.sql("select name ,StrLen(name,10) as length from user").show();

 三、UDAF函数

 UDAF:用户自定义聚合函数,user defined aggreagatefunction

package com.spark.sparksql.udf_udaf;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
/**
 * UDAF 用户自定义聚合函数
 * @author root
 *
 */
public class UDAF {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setMaster("local").setAppName("udaf");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD<String> parallelize = sc.parallelize(
                Arrays.asList("zhangsan","lisi","wangwu","zhangsan","zhangsan","lisi"));
        JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() {

            /**
             * 
             */
            private static final long serialVersionUID = 1L;

            @Override
            public Row call(String s) throws Exception {
                return RowFactory.create(s);
            }
        });
        
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
        df.registerTempTable("user");
        /**
         * 注册一个UDAF函数,实现统计相同值得个数
         * 注意:这里可以自定义一个类继承UserDefinedAggregateFunction类也是可以的
         */
        sqlContext.udf().register("StringCount",new UserDefinedAggregateFunction() {
            
            /**
             * 
             */
            private static final long serialVersionUID = 1L;
            
            /**
             * 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果
             */
            @Override
            public void initialize(MutableAggregationBuffer buffer) {
                buffer.update(0, 0);
            }
            
            /**
             * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
             * buffer.getInt(0)获取的是上一次聚合后的值
             * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合 
             * 大聚和发生在reduce端.
             * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
             */
            @Override
            public void update(MutableAggregationBuffer buffer, Row arg1) {
                buffer.update(0, buffer.getInt(0)+1);
                
            }
            /**
             * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
             * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
             * buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值       
             * buffer2.getInt(0) : 这次计算传入进来的update的结果
             * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
             * 也可以是一个节点里面的多个executor合并 reduce端大聚合
             */
            @Override
            public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
                buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
            }
            /**
             * 在进行聚合操作的时候所要处理的数据的结果的类型
             */
            @Override
            public StructType bufferSchema() {
                return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("bffer111", DataTypes.IntegerType, true)));
            }
            /**
             * 最后返回一个和DataType的类型要一致的类型,返回UDAF最后的计算结果
             */
            @Override
            public Object evaluate(Row row) {
                return row.getInt(0);
            }
            /**
             * 指定UDAF函数计算后返回的结果类型
             */
            @Override
            public DataType dataType() {
                return DataTypes.IntegerType;
            }
            /**
             * 指定输入字段的字段及类型
             */
            @Override
            public StructType inputSchema() {
                return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("nameeee", DataTypes.StringType, true)));
            }
            /**
             * 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。
             */
            @Override
            public boolean deterministic() {
                return true;
            }
            
        });
        
        sqlContext.sql("select name ,StringCount(name) as strCount from user group by name").show();
        
        
        sc.stop();
    }
}

 

传入到UDAF中的数据必须在分组字段里面,相当于是一组数据进来。

posted @ 2018-03-07 19:32  L先生AI课堂  阅读(5224)  评论(0编辑  收藏  举报