SparkSQL自定义函数
一:自定义函数分类
在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
1.UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
2.UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
3.UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap
二:自定义函数的使用UDF
(一)定义case class
case class Emp(empno:Int,ename:String,job:String,mgr:String,hiredate:String,sal:Int,comm:String,deptno:Int)
(二)导入emp.csv的文件
val lineRDD = sc.textFile("/emp.csv").map(_.split(","))
(三)生成DataFrame
val allEmp = lineRDD.map(x=>Emp(x(0).toInt,x(1),x(2),x(3),x(4),x(5).toInt,x(6),x(7).toInt)) val empDF = allEmp.toDF
(四)注册成一个临时视图
empDF.createOrReplaceTempView("emp")
(五)自定义一个函数,拼加字符串
spark.sqlContext.udf.register("concatstr",(s1:String,s2:String)=>s1+"***"+s2)
(六)调用自定义函数,将ename和job这两个字段拼接在一起
spark.sql("select concatstr(ename,job) from emp").show
三:用户自定义聚合函数UDAF,需要继承UserDefinedAggregateFunction类,并实现其中的8个方法
UDAF就是用户自定义聚合函数,比如平均值,最大最小值,累加,拼接等。这里以求平均数为例,并用Java实现
(一)实现自定义聚合函数
package SparkUDAF; import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; import java.util.List; public class MyAvg extends UserDefinedAggregateFunction { @Override public StructType inputSchema() { //输入数据的类型,输入的是字符串 List<StructField> structFields = new ArrayList<>(); structFields.add(DataTypes.createStructField("InputData", DataTypes.StringType, true)); return DataTypes.createStructType(structFields); } @Override public StructType bufferSchema() { //聚合操作时,所处理的数据的数据类型,在这个例子里求平均数,要先求和(Sum),然后除以个数(Amount),所以这里需要处理两个字段 //注意因为用了ArrayList,所以是有序的 List<StructField> structFields = new ArrayList<>(); structFields.add(DataTypes.createStructField("Amount", DataTypes.IntegerType, true)); structFields.add(DataTypes.createStructField("Sum", DataTypes.IntegerType, true)); return DataTypes.createStructType(structFields); } @Override public DataType dataType() { //UDAF计算后的返回值类型 return DataTypes.IntegerType; } @Override public boolean deterministic() { //判断输入和输出的类型是否一致,如果返回的是true则表示一致,false表示不一致,自行设置 return false; } @Override public void initialize(MutableAggregationBuffer buffer) { /* 对辅助字段进行初始化,就是上面定义的field1和field2 第一个辅助字段的下标为0,初始值为0 第二个辅助字段的下标为1,初始值为0 */ buffer.update(0, 0); buffer.update(1, 0); } @Override public void update(MutableAggregationBuffer buffer, Row input) { /* update可以认为是在每一个节点上都会对数据执行的操作,UDAF函数执行的时候,数据会被分发到每一个节点上,就是每一个分区 buffer.getInt(0)获取的是上一次聚合后的值,input就是当前获取的数据 */ //修改辅助字段的值,buffer.getInt(x)获取的是上一次聚合后的值,x表示 buffer.update(0, buffer.getInt(0) + 1); //表示某个数字的个数 buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0))); //表示某个数字的总和 } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { /* merge:对每个分区的结果进行合并,每个分布式的节点上做完update之后就要做一个全局合并的操作 合并每一个update操作的结果,将各个节点上的数据合并起来 buffer1.getInt(0) : 上一次聚合后的值 buffer2.getInt(0) : 这次计算传入进来的update的结果 */ //对第一个字段Amount进行求和,求出总个数 buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0)); //对第二个字段Sum进行求和,求出总和 buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1)); } @Override public Object evaluate(Row buffer) { //表示最终计算的结果,第二个参数表示和值,第一个参数表示个数 return buffer.getInt(1) / buffer.getInt(0); } }
(二)注册并使用UDAF
package SparkUDAF; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.util.ArrayList; import java.util.List; public class TestMain { public static void main(String[] args) { SparkConf conf =new SparkConf(); conf.setMaster("local").setAppName("MyAvg"); JavaSparkContext sc= new JavaSparkContext(conf); //得到SQLContext对象 SQLContext sqlContext = new SQLContext(sc); //注册自定义函数 sqlContext.udf().register("my_avg",new MyAvg()); //读入数据 JavaRDD<String> lines = sc.textFile("d:\\test.txt"); //分词 JavaRDD<Row> rows=lines.map(line-> RowFactory.create(line.split("\\^"))); //定义schema的结构,a字段是字母,b字段是value List<StructField> structFields = new ArrayList<>(); structFields.add(DataTypes.createStructField("a",DataTypes.StringType,true)); structFields.add(DataTypes.createStructField("b",DataTypes.StringType,true)); StructType structType = DataTypes.createStructType(structFields); //创建DataFrame Dataset ds=sqlContext.createDataFrame(rows,structType); ds.registerTempTable("test"); //执行查询 sqlContext.sql("select a,my_avg(b) from test group by a").show(); sc.stop(); } }