【Spark-SQL Notes, Part 3】 UDF, UDAF, and Window Functions

Environment
  VM: VMware 10
  Linux: CentOS-6.5-x86_64
  SSH client: Xshell 4
  FTP: Xftp 4
  JDK 1.8
  Scala 2.10.4 (requires JDK 1.8)
  Spark 1.6


1. UDF: user-defined functions
In the Java API, a UDF is a class that implements one of the UDF1, UDF2, ... UDF22 interfaces, chosen according to the number of arguments the function takes; in Scala you can simply register an ordinary function.

Example code:
Java:

package com.wjy.df;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class UDF {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("UDF");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        
        JavaRDD<String> rdd = sc.parallelize(Arrays.asList("xiaoming","xiaohong","xiaolei"));
        JavaRDD<Row> rdd2 = rdd.map(new Function<String, Row>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Row call(String str) throws Exception {
                return RowFactory.create(str);
            }
        });
        
        /**
         * Build the DataFrame by creating the schema dynamically
         */
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        DataFrame dataFrame = sqlContext.createDataFrame(rdd2, schema);
        dataFrame.registerTempTable("user");
        
        // Register a function that computes the length of a string
        /**
         * Implement UDF1, UDF2, ... according to the number of arguments the UDF takes (the interfaces go up to UDF22)
         */
        sqlContext.udf().register("StrLen", new UDF1<String, Integer>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Integer call(String str) throws Exception {
                return str.length();
            }
        },DataTypes.IntegerType);
        sqlContext.sql("select name ,StrLen(name) as length from user").show();
        /*
         * +--------+------+
         * |    name|length|
         * +--------+------+
         * |xiaoming|     8|
         * |xiaohong|     8|
         * | xiaolei|     7|
         * +--------+------+
         */
        
        sqlContext.udf().register("StrLen2", new UDF2<String, Integer, Integer>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Integer call(String str, Integer num) throws Exception {
                return str.length()+num;
            }
        }, DataTypes.IntegerType);
        sqlContext.sql("select name ,StrLen2(name,10) as length from user").show();
        /*
         * +--------+------+
         * |    name|length|
         * +--------+------+
         * |xiaoming|    18|
         * |xiaohong|    18|
         * | xiaolei|    17|
         * +--------+------+
         */
        
        sc.stop();
    }

}

 

Scala:

package com.wjy.df

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.SQLContext

object UDF {
  def main(args:Array[String]):Unit={
    val conf = new SparkConf().setMaster("local").setAppName("UDF");
    val sc = new SparkContext(conf);
    val sqlContext = new SQLContext(sc);
    val rdd = sc.makeRDD(Array("zhansan","lisi","wangwu"));
    val row = rdd.map(x=>{
      RowFactory.create(x);
    });
    val schema = DataTypes.createStructType(Array(StructField("name",StringType,true)));
    val df = sqlContext.createDataFrame(row, schema);
    df.show; // the parentheses on show() can be omitted in Scala
    df.registerTempTable("user");
    
    //StrLen
    sqlContext.udf.register("StrLen", (s:String)=>{s.length()});
    sqlContext.sql("select name ,StrLen(name) as length from user").show;
    
    //StrLen2
    sqlContext.udf.register("StrLen2", (s:String,i:Integer)=>{s.length()+i});
    sqlContext.sql("select name ,StrLen2(name,10) as length from user").show;
    
    sc.stop();
  }
}
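
Once registered, a UDF can also be invoked through the DataFrame API instead of a SQL string. A minimal sketch (these lines would go before sc.stop() in the example above), assuming the df and the StrLen UDF registered there, and using org.apache.spark.sql.functions.callUDF, which is available since Spark 1.5:

    import org.apache.spark.sql.functions.{col, callUDF}

    // Call the registered UDF through the DataFrame API instead of sqlContext.sql(...)
    df.select(col("name"), callUDF("StrLen", col("name")).as("length")).show()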

 

2. UDAF: user-defined aggregate functions
To implement a UDAF, define a class (named or anonymous) that extends UserDefinedAggregateFunction and register it with the SQLContext.

Example code:
Java:

 

package com.wjy.df;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * UDAF: user-defined aggregate function
 * @author root
 *
 */
public class UDAF {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("UDAF");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDD<String> parallelize = sc.parallelize(
                Arrays.asList("zhangsan","lisi","wangwu","zhangsan","zhangsan","lisi"));
        JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Row call(String s) throws Exception {
                return RowFactory.create(s);
            }
        });
        
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
        df.registerTempTable("user");
        
        /**
         * Register a UDAF that counts how many times each value occurs.
         * Note: defining a named class that extends UserDefinedAggregateFunction works just as well (see the Scala version below).
         */
        sqlContext.udf().register("StringCount",new UserDefinedAggregateFunction(){
            private static final long serialVersionUID = 1L;

            /**
             * Initialize the aggregation buffer: the per-group starting value before any rows are aggregated
             */
            @Override
            public void initialize(MutableAggregationBuffer buffer) {
                buffer.update(0, 0);
            }
            
            /**
             * The names and types of the UDAF's input fields
             */
            @Override
            public StructType inputSchema() {
                return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("name", DataTypes.StringType, true)));
            }
            
            /**
             * update: called once for each input row of a group within a partition.
             * buffer.getInt(0) is the value accumulated so far for that group.
             * This is analogous to a map-side combiner in MapReduce: each task pre-aggregates
             * its own rows, and the global aggregation happens later in merge on the reduce side.
             * Here we simply add 1 for every incoming row.
             */
            @Override
            public void update(MutableAggregationBuffer buffer, Row arg1) {
                buffer.update(0, buffer.getInt(0)+1);
            }
            
            /**
             * The schema of the intermediate aggregation buffer
             */
            @Override
            public StructType bufferSchema() {
                return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("buffer", DataTypes.IntegerType, true)));
            }

            /**
             * merge: combines partial aggregates. The rows of one group may be processed by update
             * on several partitions/nodes, each producing its own buffer; merge combines those
             * partial buffers into the final per-group value.
             * buffer1.getInt(0): the value accumulated so far in the global aggregation
             * buffer2.getInt(0): an incoming partial result produced by update
             */
            @Override
            public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
                buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
            }
            
            /**
             * The data type of the UDAF's final result
             */
            @Override
            public DataType dataType() {
                return DataTypes.IntegerType;
            }

            /**
             * Produce the final result from the buffer; its type must match dataType()
             */
            @Override
            public Object evaluate(Row row) {
                return row.getInt(0);
            }
            
            /**
             * Whether the UDAF always produces the same result for the same input; normally true
             */
            @Override
            public boolean deterministic() {
                return true;
            }

            });
        
        sqlContext.sql("select name ,StringCount(name) as strCount from user group by name").show();
        sc.stop();
    }

}
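
With the input above ("zhangsan" three times, "lisi" twice, "wangwu" once), the group-by query should print roughly the following (row order may vary):

name      strCount
zhangsan  3
lisi      2
wangwu    1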

 

Scala:

 

package com.wjy.df

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.DataType

class MyUDAF extends UserDefinedAggregateFunction{ 
  // Initialize the aggregation buffer for each group
  def initialize(buffer: MutableAggregationBuffer): Unit = {
     buffer(0) = 0
  }
  
  // Schema of the input data
  def inputSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("input", StringType, true)))
  }
  
  // Called for each new input row of a group: update that group's running aggregate
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0)+1
  }
  
  // Schema of the intermediate aggregation buffer
  def bufferSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("aaa", IntegerType, true)))
  }
  
  // Merge the partial aggregates computed on different partitions/nodes
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0)+buffer2.getAs[Int](0) 
  }
  
  // Data type of the final return value
  def dataType: DataType = {
    DataTypes.IntegerType
  }
  
  // Return the final aggregate value; it must match the type declared in dataType
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }

  // Deterministic: the same input always yields the same result
  def deterministic: Boolean = {
    true
  }
}

object UDAF {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setMaster("local").setAppName("udaf")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.makeRDD(Array("zhangsan","lisi","wangwu","zhangsan","lisi"))
    val rowRDD = rdd.map { x => {RowFactory.create(x)} }
    
    val schema = DataTypes.createStructType(Array(DataTypes.createStructField("name", StringType, true)))
    val df = sqlContext.createDataFrame(rowRDD, schema)
    df.show()
    df.registerTempTable("user")
    /**
     * Register the UDAF
     */
    sqlContext.udf.register("StringCount", new MyUDAF())
    sqlContext.sql("select name ,StringCount(name) as count from user group by name").show()
    sc.stop()
  }
}
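
Because this UDAF simply counts the rows in each group, its output can be cross-checked against the built-in count aggregate. A quick sanity check (a sketch, to be placed before sc.stop() above, using the same user temp table):

    // The built-in count should return the same numbers as the custom StringCount UDAF
    sqlContext.sql("select name, StringCount(name) as myCount, count(name) as builtinCount from user group by name").show()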

 

3. Window functions
Window function syntax:
row_number() over (partition by XXX order by XXX)
Notes:
row_number() numbers the rows within each partition (group) according to the given ordering, so wrapping it in a subquery and filtering on that number gives a grouped top-N;
in Spark 1.6, a SQL statement that uses a window function must be executed with a HiveContext rather than a plain SQLContext, and by default a HiveContext cannot be created in a purely local setup, so the examples below assume a Hive-enabled environment.

Example code:
Java:

package com.wjy.df;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.hive.HiveContext;

public class RowNumberWindowFun {

    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setAppName("windowfun");
        conf.set("spark.sql.shuffle.partitions","1");
        JavaSparkContext sc = new JavaSparkContext(conf);
        HiveContext hiveContext = new HiveContext(sc);
        hiveContext.sql("use spark");
        hiveContext.sql("drop table if exists sales");
        hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) "
                + "row format delimited fields terminated by '\t'");
        hiveContext.sql("load data local inpath '/root/test/sales' into table sales");
        
        /**
         * Window function syntax:
         * row_number() over (partition by XXX order by XXX DESC) as rank
         * Note: rank starts at 1
         */
        /**
         * Partition by category (leibie) and sort each category by amount (jine) in descending order,
         * returning (riqi, leibie, jine). For example, given:
         *
         * 1 A 100
         * 2 B 200
         * 3 A 300
         * 4 B 400
         * 5 A 500
         * 6 B 600
         * after ranking:
         * 5 A 500  --rank 1
         * 3 A 300  --rank 2
         * 1 A 100  --rank 3
         * 6 B 600  --rank 1
         * 4 B 400  --rank 2
         * 2 B 200  --rank 3
         *
         */
        DataFrame result = hiveContext.sql("select riqi,leibie,jine "
                            + "from ("
                                + "select riqi,leibie,jine,"
                                + "row_number() over (partition by leibie order by jine desc) rank "
                                + "from sales) t "
                        + "where t.rank<=3");
        result.show(100);
        /**
         * Save the result to the Hive table sales_result
         */
        result.write().mode(SaveMode.Overwrite).saveAsTable("sales_result");
        sc.stop();
    }
}
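
The load statement above expects /root/test/sales to be a plain text file whose columns match the sales table (riqi, leibie, jine) and are separated by tabs. Based on the sample values in the comment, an illustrative file would look roughly like this (fields separated by a tab character):

1	A	100
2	B	200
3	A	300
4	B	400
5	A	500
6	B	600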

 

Scala:

 

package com.wjy.df

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.hive.HiveContext

object RowNumberWindowFun {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("windowfun")
    val sc = new SparkContext(conf)
    val hiveContext = new HiveContext(sc)
    hiveContext.sql("use spark")
    hiveContext.sql("drop table if exists sales")
    hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) "
        + "row format delimited fields terminated by '\t'")
    hiveContext.sql("load data local inpath '/root/test/sales' into table sales")

    /**
     * Window function syntax:
     * row_number() over (partition by XXX order by XXX)
     */
    val result = hiveContext.sql("select riqi,leibie,jine "
        + "from ("
        + "select riqi,leibie,jine,"
        + "row_number() over (partition by leibie order by jine desc) rank "
        + "from sales) t "
        + "where t.rank<=3")
    result.show()
    sc.stop()
  }
}
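
The Java version also persists the ranked result to a Hive table; the same can be done here. A sketch (to be placed before sc.stop() above, reusing the sales_result table name from the Java example; it requires import org.apache.spark.sql.SaveMode):

    // Overwrite (or create) the Hive table sales_result with the top-3 rows per category,
    // then read it back to verify
    result.write.mode(SaveMode.Overwrite).saveAsTable("sales_result")
    hiveContext.table("sales_result").show()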

 

