Custom Spark UDAFs
Spark provides two ways to define a custom aggregate function:
Untyped User-Defined Aggregate Functions
Untyped custom aggregate functions, mainly used with DataFrames
Type-Safe User-Defined Aggregate Functions
Type-safe custom aggregate functions, mainly used with Datasets
Sample code for an untyped custom aggregate function:
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public static class MyAverage extends UserDefinedAggregateFunction {

  private StructType inputSchema;
  private StructType bufferSchema;

  public MyAverage() {
    List<StructField> inputFields = new ArrayList<>();
    inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
    inputSchema = DataTypes.createStructType(inputFields);

    List<StructField> bufferFields = new ArrayList<>();
    bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
    bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
    bufferSchema = DataTypes.createStructType(bufferFields);
  }
  // Data types of input arguments of this aggregate function
  public StructType inputSchema() {
    return inputSchema;
  }
  // Data types of values in the aggregation buffer
  public StructType bufferSchema() {
    return bufferSchema;
  }
  // The data type of the returned value
  public DataType dataType() {
    return DataTypes.DoubleType;
  }
  // Whether this function always returns the same output on the identical input
  public boolean deterministic() {
    return true;
  }
  // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.
  public void initialize(MutableAggregationBuffer buffer) {
    buffer.update(0, 0L);
    buffer.update(1, 0L);
  }
  // Updates the given aggregation buffer `buffer` with new input data from `input`
  public void update(MutableAggregationBuffer buffer, Row input) {
    if (!input.isNullAt(0)) {
      long updatedSum = buffer.getLong(0) + input.getLong(0);
      long updatedCount = buffer.getLong(1) + 1;
      buffer.update(0, updatedSum);
      buffer.update(1, updatedCount);
    }
  }
  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
    long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
    long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
    buffer1.update(0, mergedSum);
    buffer1.update(1, mergedCount);
  }
  // Calculates the final result
  public Double evaluate(Row buffer) {
    return ((double) buffer.getLong(0)) / buffer.getLong(1);
  }
}

// Register the function to access it
spark.udf().register("myAverage", new MyAverage());

Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
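Besides SQL, a registered UDAF can also be invoked through the DataFrame API. A minimal sketch, assuming the `spark` session and the `df` DataFrame from the example above:

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;

// Invoke the registered UDAF by name on the salary column;
// this is equivalent to the SQL query above.
Dataset<Row> apiResult = df.select(callUDF("myAverage", col("salary")).as("average_salary"));
apiResult.show();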
Sample code 2:
import java.util.Arrays;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

/**
 * Concatenates distinct values within a group (group_concat_distinct())
 */
public class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {

  private static final long serialVersionUID = -2510776241322950505L;

  // Fields and types of the input data.
  // Users can choose names to identify the input arguments;
  // "cityInfo" here is arbitrary.
  private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
      DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));

  // Fields and types of the aggregation buffer,
  // i.e. the intermediate result handled during aggregation
  private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
      DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)));

  // The return type
  private DataType dataType = DataTypes.StringType;

  // Whether the function is deterministic:
  // given the same input, it always returns the same output
  private boolean deterministic = true;

  @Override
  public StructType inputSchema() {
    return inputSchema;
  }

  @Override
  public StructType bufferSchema() {
    return bufferSchema;
  }

  @Override
  public DataType dataType() {
    return dataType;
  }

  @Override
  public boolean deterministic() {
    return deterministic;
  }

  /**
   * Initializes the given aggregation buffer,
   * i.e. sets whatever initial value you choose for the buffer.
   */
  @Override
  public void initialize(MutableAggregationBuffer buffer) {
    buffer.update(0, "");
  }

  /**
   * Update: the values within a group are passed in one by one,
   * and this method implements the concatenation logic.
   *
   * Whenever a new value arrives, it is folded into the per-group aggregate.
   * This is a local aggregation, comparable to the Combiner in the
   * Hadoop MapReduce model.
   */
  @Override
  public void update(MutableAggregationBuffer buffer, Row input) {
    // The city info already concatenated in the buffer
    String bufferCityInfo = buffer.getString(0);
    // The city info that was just passed in
    String cityInfo = input.getString(0);

    // Deduplication: only append a city info string that has not been seen yet.
    // Note: contains() is a substring check, so this assumes no city info
    // string is a substring of another.
    if (!bufferCityInfo.contains(cityInfo)) {
      if ("".equals(bufferCityInfo)) {
        bufferCityInfo += cityInfo;
      } else {
        // e.g. "1:Beijing" becomes "1:Beijing,2:Shanghai"
        bufferCityInfo += "," + cityInfo;
      }
      buffer.update(0, bufferCityInfo);
    }
  }

  /**
   * Merge: update() may run on only part of a group's data on one node,
   * while the rest of the group is processed on other nodes.
   * merge() combines the partial strings built on each node.
   */
  @Override
  public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
    String bufferCityInfo1 = buffer1.getString(0);
    String bufferCityInfo2 = buffer2.getString(0);

    for (String cityInfo : bufferCityInfo2.split(",")) {
      if (!bufferCityInfo1.contains(cityInfo)) {
        if ("".equals(bufferCityInfo1)) {
          bufferCityInfo1 += cityInfo;
        } else {
          bufferCityInfo1 += "," + cityInfo;
        }
      }
    }
    buffer1.update(0, bufferCityInfo1);
  }

  @Override
  public Object evaluate(Row row) {
    return row.getString(0);
  }
}
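The class above only defines the function; it must be registered before use. A minimal usage sketch, assuming an existing `spark` session; the temp view `tmp_area_city` and its columns are hypothetical names for illustration:

// Register the UDAF under a SQL-callable name
spark.udf().register("group_concat_distinct", new GroupConcatDistinctUDAF());

// Use it with GROUP BY; concat_ws builds the "id:name" strings fed into the UDAF
Dataset<Row> grouped = spark.sql(
    "SELECT area, group_concat_distinct(concat_ws(':', city_id, city_name)) AS city_infos "
  + "FROM tmp_area_city GROUP BY area");
grouped.show();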
Sample code for a type-safe custom aggregate function:
import java.io.Serializable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.expressions.Aggregator;

public static class Employee implements Serializable {
  private String name;
  private long salary;

  // Constructors, getters, setters...
}

public static class Average implements Serializable {
  private long sum;
  private long count;

  // Constructors, getters, setters...
}

public static class MyAverage extends Aggregator<Employee, Average, Double> {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  public Average reduce(Average buffer, Employee employee) {
    long newSum = buffer.getSum() + employee.getSalary();
    long newCount = buffer.getCount() + 1;
    buffer.setSum(newSum);
    buffer.setCount(newCount);
    return buffer;
  }
  // Merge two intermediate values
  public Average merge(Average b1, Average b2) {
    long mergedSum = b1.getSum() + b2.getSum();
    long mergedCount = b1.getCount() + b2.getCount();
    b1.setSum(mergedSum);
    b1.setCount(mergedCount);
    return b1;
  }
  // Transform the output of the reduction
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  // Specifies the Encoder for the intermediate value type
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  // Specifies the Encoder for the final output value type
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}

Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

MyAverage myAverage = new MyAverage();
// Convert the function to a `TypedColumn` and give it a name
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
Dataset<Double> result = ds.select(averageSalary);
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
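Since Spark 3.0, UserDefinedAggregateFunction is deprecated, and a type-safe Aggregator can also be registered for SQL use via functions.udaf. A sketch, assuming Spark 3.0+, the MyAverage Aggregator above, and the employees temp view registered in the first example:

import org.apache.spark.sql.functions;

// Wrap the Aggregator as an untyped UDF and register it under a SQL name
// (functions.udaf is available since Spark 3.0)
spark.udf().register("myAverage",
    functions.udaf(new MyAverage(), Encoders.bean(Employee.class)));

Dataset<Row> sqlResult =
    spark.sql("SELECT myAverage(salary) AS average_salary FROM employees");
sqlResult.show();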
Related APIs