Flink 操作 —— 计算函数

一、窗口流 WindowedStream

通常由 KeyedStream 调用 window(...)(传入 WindowAssigner)生成,countWindow、timeWindow 等是它的快捷写法。运行时会将 KeyedStream 与窗口上的操作合并为一个算子(operator)。

 

aggregate

用于按字段名或按位置(元组下标)对每个窗口内的元素做聚合(sum、min、max、minBy、maxBy);分组本身由前面的 keyBy 完成。

  /** Field-name variant of `aggregate`: resolves `field` to its positional index and delegates. */
  private def aggregate(aggregationType: AggregationType, field: String): DataStream[T] = {
    val indices = fieldNames2Indices(getInputType(), Array(field))
    aggregate(aggregationType, indices(0))
  }

  /**
   * Aggregates the field at `position` within each window.
   *
   * SUM is implemented by a SumAggregator; every other AggregationType by a ComparableAggregator.
   * Both are ReduceFunctions applied through the underlying Java windowed stream.
   */
  def aggregate(aggregationType: AggregationType, position: Int): DataStream[T] = {
    // View the Java windowed stream as carrying Products so the positional aggregators apply.
    val javaWindowed = javaStream.asInstanceOf[JavaWStream[Product, K, W]]
    val inputType = javaWindowed.getInputType
    val config = javaWindowed.getExecutionEnvironment.getConfig

    val aggregator =
      if (aggregationType == AggregationType.SUM) {
        new SumAggregator(position, inputType, config)
      } else {
        // NOTE(review): the boolean is presumably ComparableAggregator's `first` flag — confirm
        new ComparableAggregator(position, inputType, aggregationType, true, config)
      }

    new DataStream[Product](javaWindowed.reduce(aggregator)).asInstanceOf[DataStream[T]]
  }

聚合函数抽象类

/**
 * Common base for the built-in field aggregations; every aggregation is a ReduceFunction.
 */
public abstract class AggregationFunction<T> implements ReduceFunction<T> {

	private static final long serialVersionUID = 1L;

	/**
	 * Aggregation types that can be used on a windowed stream or keyed stream.
	 * In WindowedStream.aggregate, SUM maps to SumAggregator and all others to ComparableAggregator.
	 */
	public enum AggregationType {
		SUM,
		MIN,
		MAX,
		MINBY,
		MAXBY
	}
}

该抽象类的实现有 SumAggregator,ComparableAggregator

如 SumAggregator,实现了 ReduceFunction 接口

public class SumAggregator<T> extends AggregationFunction<T> {

	private static final long serialVersionUID = 1L;

	private final FieldAccessor<T, Object> fieldAccessor;
	private final SumFunction adder;
	private final TypeSerializer<T> serializer;
	private final boolean isTuple;

	public SumAggregator(int pos, TypeInformation<T> typeInfo, ExecutionConfig config) {
		fieldAccessor = FieldAccessorFactory.getAccessor(typeInfo, pos, config);
		adder = SumFunction.getForClass(fieldAccessor.getFieldType().getTypeClass());
		if (typeInfo instanceof TupleTypeInfo) {
			isTuple = true;
			serializer = null;
		} else {
			isTuple = false;
			this.serializer = typeInfo.createSerializer(config);
		}
	}

	public SumAggregator(String field, TypeInformation<T> typeInfo, ExecutionConfig config) {
		fieldAccessor = FieldAccessorFactory.getAccessor(typeInfo, field, config);
		adder = SumFunction.getForClass(fieldAccessor.getFieldType().getTypeClass());
		if (typeInfo instanceof TupleTypeInfo) {
			isTuple = true;
			serializer = null;
		} else {
			isTuple = false;
			this.serializer = typeInfo.createSerializer(config);
		}
	}

	@Override
	@SuppressWarnings("unchecked")
	public T reduce(T value1, T value2) throws Exception {
		if (isTuple) {
			Tuple result = ((Tuple) value1).copy();
			return fieldAccessor.set((T) result, adder.add(fieldAccessor.get(value1), fieldAccessor.get(value2)));
		} else {
			T result = serializer.copy(value1);
			return fieldAccessor.set(result, adder.add(fieldAccessor.get(value1), fieldAccessor.get(value2)));
		}
	}
}

常见派生方法

sum

/** Sums the field at `position` within each window (AggregationType.SUM). */
def sum(position: Int): DataStream[T] =
  aggregate(aggregationType = AggregationType.SUM, position = position)

maxBy

/** Aggregates the field at `position` within each window using AggregationType.MAXBY. */
def maxBy(position: Int): DataStream[T] =
  aggregate(aggregationType = AggregationType.MAXBY, position = position)

示例代码片段

sum

// Word counts over a sliding count window: sum the count field (tuple index 1) per word.
val counts: DataStream[(String, Int)] = text
  .flatMap(line => line.toLowerCase().split("\\W+"))
  .filter(word => word.nonEmpty)
  .map(word => (word, 1))
  .keyBy(0)
  .countWindow(windowSize, slideSize)
  .sum(1)

maxBy

// Same pipeline as the sum example, but aggregating with maxBy on tuple index 1.
val counts: DataStream[(String, Int)] = text
  .flatMap(line => line.toLowerCase().split("\\W+"))
  .filter(word => word.nonEmpty)
  .map(word => (word, 1))
  .keyBy(0)
  .countWindow(windowSize, slideSize)
  .maxBy(1)

 

reduce

与 aggregate 类似,reduce 会针对每个 key 在每个窗口内的元素集合应用该归约函数。

/**
 * Reduces the elements of each window with a Scala closure.
 *
 * @throws NullPointerException if `function` is null
 */
def reduce(function: (T, T) => T): DataStream[T] = {
  if (function eq null) {
    throw new NullPointerException("Reduce function must not be null.")
  }
  // wrap the cleaned closure in a ReduceFunction and delegate to the ReduceFunction overload
  reduce(new ScalaReduceFunction[T](clean(function)))
}

同样是实现了 ReduceFunction 接口

/** Adapts a Scala `(T, T) => T` closure to Flink's `ReduceFunction` interface. */
final class ScalaReduceFunction[T](private[this] val function: (T, T) => T)
    extends ReduceFunction[T] {

  /** Combines two values by delegating to the wrapped closure. */
  @throws(classOf[Exception])
  override def reduce(a: T, b: T): T = function(a, b)
}

示例代码片段

// Sliding time window (2.5 s size, 0.5 s slide): keep the key, sum the second field.
stream
  .keyBy(0)
  .timeWindow(Time.of(2500, TimeUnit.MILLISECONDS), Time.of(500, TimeUnit.MILLISECONDS))
  .reduce { (left, right) => (left._1, left._2 + right._2) }
  .addSink(new SinkFunction[(Long, Long)] {})

 

process

代码示例

// Count sensor readings per sensor id in custom windows, with a custom early-firing trigger.
val countsPerThirtySecs = sensorData
  .keyBy(_.id)
  .window(new ThirtySecondsWindows)      // custom assigner: 30-second tumbling windows
  .trigger(new OneSecondIntervalTrigger) // custom trigger: fires early, at most once per second
  .process(new CountFunction)            // emits the reading count on every firing

...

/** Emits (key, window end, current watermark, number of readings) whenever the window is evaluated. */
class CountFunction
    extends ProcessWindowFunction[SensorReading, (String, Long, Long, Int), String, TimeWindow] {

  override def process(
      key: String,
      ctx: Context,
      readings: Iterable[SensorReading],
      out: Collector[(String, Long, Long, Int)]): Unit = {

    // number of readings in the window (size instead of count(_ => true))
    val cnt = readings.size
    // watermark at evaluation time, so early firings can be told apart from the final one
    val evalTime = ctx.currentWatermark
    out.collect((key, ctx.window.getEnd, evalTime, cnt))
  }
}

 

apply

代码示例

// Average temperature per sensor over 1-second tumbling windows, after converting Fahrenheit to Celsius.
val avgTemp: DataStream[SensorReading] = sensorData
  .map(r => SensorReading(r.id, r.timestamp, (r.temperature - 32) * (5.0 / 9.0)))
  .keyBy(_.id)
  .timeWindow(Time.seconds(1))
  .apply(new TemperatureAverager)

...

class TemperatureAverager extends WindowFunction[SensorReading, SensorReading, String, TimeWindow] {

  /** Invoked once per window: emits the mean temperature of the window, stamped with the window end. */
  override def apply(
    sensorId: String,
    window: TimeWindow,
    vals: Iterable[SensorReading],
    out: Collector[SensorReading]): Unit = {

    // single pass accumulating (number of readings, temperature total)
    val (cnt, sum) = vals.foldLeft((0, 0.0)) { case ((n, total), reading) =>
      (n + 1, total + reading.temperature)
    }

    out.collect(SensorReading(sensorId, window.getEnd, sum / cnt))
  }
}

 

二、常见计算函数

flatMap

示例(DataStream 的 flatMap)

// Approach 1: pass a Scala function returning a collection; its elements are flattened into the output stream
val splitIds = sensorIds.flatMap(id => id.split("_"))
// Method signature: the return type DataStream[R] is the operator's output stream
def flatMap[R: TypeInformation](fun: T => TraversableOnce[R]): DataStream[R]

// Approach 2:
// Implement the FlatMapFunction interface: the method returns Unit and output records are emitted through the collector
class SplitIdFlatMap extends FlatMapFunction[String, String] {

  /** Splits `id` on "_" and emits every part through `collector`. */
  override def flatMap(id: String, collector: Collector[String]): Unit =
    // the split result must be passed to the collector; returning it from a Unit method discards it
    id.split("_").foreach(collector.collect)

}

 

对比 keyedStream.sum 与 windowedStream.sum

object RollingSum {

  /** Rolling (per-record) sum of field 1, keyed by field 0. */
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment

    val rollingTotals = env
      .fromElements((1, 2, 2), (2, 3, 1), (2, 2, 4), (1, 5, 3))
      .keyBy(0)
      .sum(1)

    rollingTotals.print()
    // Sample output (the "N>" prefix is added by print()):
    //6> (1,2,2)
    //8> (2,3,1)
    //6> (1,7,2)
    //8> (2,5,1)
    env.execute("Rolling sum example")
  }

}

其输出结果相当于下面的窗口写法(countWindow(Int.MaxValue, 1) 每到达一条元素就计算一次)。但窗口版本会在状态中保留该 key 的全部元素,开销比滚动聚合大:

// Sliding count window of size Int.MaxValue with a slide of 1 element, summing field 1 per key.
val resultStream = inputSteam.keyBy(0).countWindow(Int.MaxValue, 1).sum(1)

 

keyedStream.process

示例

object ProcessFunctionTimers {

  /** Keys sensor readings by id and runs TempIncreaseAlertFunction on each key. */
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    val warnings = env
      .addSource(new SensorSource)
      .keyBy(_.id)
      .process(new TempIncreaseAlertFunction)

    warnings.print()
    env.execute("Monitor sensor temperatures.")
  }
}

 KeyedProcessFunction 示例

class TempIncreaseAlertFunction extends KeyedProcessFunction[String, SensorReading, String] {

  // Last temperature for the current key; the code treats 0.0 as "no earlier reading".
  lazy val lastTemp: ValueState[Double] = getRuntimeContext.getState(
    new ValueStateDescriptor[Double]("lastTemp", Types.of[Double]))

  // Timestamp of the pending processing-time timer for the current key; 0 means none.
  lazy val currentTimer: ValueState[Long] = getRuntimeContext.getState(
    new ValueStateDescriptor[Long]("timer", Types.of[Long]))

  /** Arms a 1-second timer on a temperature rise and cancels it on a drop. */
  override def processElement(value: SensorReading,
                              ctx: KeyedProcessFunction[String, SensorReading, String]#Context,
                              out: Collector[String]): Unit = {
    val previous = lastTemp.value()
    lastTemp.update(value.temperature)
    val pendingTimer = currentTimer.value()
    val timers = ctx.timerService()

    // Nothing to compare against until a previous (non-zero) temperature exists.
    if (previous != 0.0) {
      if (value.temperature < previous) {
        // temperature dropped: cancel the pending timer
        timers.deleteProcessingTimeTimer(pendingTimer)
        currentTimer.clear()
      } else if (value.temperature > previous && pendingTimer == 0) {
        // temperature rose and no timer is armed yet: fire in 1 second unless it drops first
        val fireAt = timers.currentProcessingTime() + 1000
        timers.registerProcessingTimeTimer(fireAt)
        currentTimer.update(fireAt)
      }
    }
  }

  /** Emits the warning and forgets the timer once it fires. */
  override def onTimer(
                        ts: Long,
                        ctx: KeyedProcessFunction[String, SensorReading, String]#OnTimerContext,
                        out: Collector[String]): Unit = {
    out.collect("Temperature of sensor '" + ctx.getCurrentKey +
      "' monotonically increased for 1 second.")
    currentTimer.clear()
  }
}

 

三、RichFunction 与 ProcessFunction

RichFunction

主要有3个方法 open,close,getRuntimeContext

/**
 * Function variant with lifecycle hooks and access to the runtime context
 * (the examples below use getRuntimeContext() to obtain keyed state).
 */
public interface RichFunction extends Function {
  
  /** Setup hook; see the Flink RichFunction Javadoc for exactly when it is invoked. */
  void open(Configuration parameters) throws Exception;

  /** Teardown hook, counterpart of open(). */
  void close() throws Exception;

  /** Returns the runtime context, e.g. used to create ValueState via getState(...). */
  RuntimeContext getRuntimeContext();

    ...
}

 

KeyedProcessFunction

主要有 2 个方法, processElement,onTimer

/**
 * Process function for keyed streams: K = key type, I = input element type, O = output type.
 * It is a rich function, so keyed state is available through getRuntimeContext().
 */
public abstract class KeyedProcessFunction<K, I, O> extends AbstractRichFunction {

    /** Called for each input element; results are emitted through {@code out}. */
    public abstract void processElement(I value, Context ctx, Collector<O> out) throws Exception;

    /** Called when a timer registered via the timer service fires; the default does nothing. */
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<O> out) throws Exception {
    }
    ...
}

 

示例1:

object KeyedProcessTest {

  /** Reads "id,timestamp,temperature" lines from a socket and runs TempIncrementAlertFunction per sensor. */
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(1)
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    val readings: DataStream[SensorReading] = env
      .socketTextStream("localhost", 7777)
      .map { line =>
        val parts = line.split(",")
        SensorReading(parts(0), parts(1).toLong, parts(2).toDouble)
      }
      .assignTimestampsAndWatermarks(
        new BoundedOutOfOrdernessTimestampExtractor[SensorReading](Time.seconds(1)) {
          // the timestamp field is scaled by 1000 (seconds -> milliseconds)
          override def extractTimestamp(element: SensorReading): Long = element.timestamp * 1000
        })

    val alerts = readings
      .keyBy(_.id)
      .process(new TempIncrementAlertFunction())

    alerts.print()

    env.execute()
  }
}

class TempIncrementAlertFunction() extends KeyedProcessFunction[String, SensorReading, String] {

  // Last temperature for the current key; 0.0 is treated as "no previous reading".
  lazy val lastTemp: ValueState[Double] = getRuntimeContext.getState(new ValueStateDescriptor[Double]("lastTemp", Types.of[Double]))
  // Timestamp of the pending alert timer for the current key; 0 when no timer is registered.
  lazy val currentTimer: ValueState[Long] = getRuntimeContext.getState(new ValueStateDescriptor[Long]("currentTimer", Types.of[Long]))

  /**
   * Arms a 10-second processing-time timer when the temperature rises (and none is pending),
   * and cancels it when the temperature drops. The first reading for a key only records the value.
   */
  override def processElement(value: SensorReading,
                              ctx: KeyedProcessFunction[String, SensorReading, String]#Context,
                              out: Collector[String]): Unit = {
    val preTemp = lastTemp.value()
    lastTemp.update(value.temperature)
    val curTimerTs = currentTimer.value()
    // The first-reading check must run before the "rising" check. Otherwise any positive first
    // reading counts as a rise from 0.0 and arms a timer with nothing to compare against.
    // NOTE(review): a genuine reading of exactly 0.0 is indistinguishable from "no reading yet".
    if (preTemp == 0.0 || value.temperature < preTemp) {
      // first reading or temperature dropped: cancel any pending timer and clear its state
      if (curTimerTs != 0) {
        ctx.timerService().deleteProcessingTimeTimer(curTimerTs)
      }
      currentTimer.clear()
    } else if (value.temperature > preTemp && curTimerTs == 0) {
      // temperature rose and no timer is pending: alert in 10 seconds unless it drops first
      val timerTs = ctx.timerService().currentProcessingTime() + 10000L
      ctx.timerService().registerProcessingTimeTimer(timerTs)
      currentTimer.update(timerTs)
    }
  }

  /** Emits the alert for the current key and forgets the fired timer. */
  override def onTimer(timestamp: Long,
                       ctx: KeyedProcessFunction[String, SensorReading, String]#OnTimerContext,
                       out: Collector[String]): Unit = {
    out.collect("sensor_" + ctx.getCurrentKey + " 温度连续上升")
    currentTimer.clear()
  }
}

 

示例2:侧路输出

object SideOutProcessTest {

  /** Routes readings through FreezingAlert and prints both the main output and the alert side output. */
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(1)
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    // Must use the same tag id that FreezingAlert emits to, or the side output is empty.
    val freezingTag = new OutputTag[String]("freezing alert")

    val readings: DataStream[SensorReading] = env
      .socketTextStream("localhost", 7777)
      .map { line =>
        val parts = line.split(",")
        SensorReading(parts(0), parts(1).toLong, parts(2).toDouble)
      }
      .assignTimestampsAndWatermarks(
        new BoundedOutOfOrdernessTimestampExtractor[SensorReading](Time.seconds(1)) {
          // the timestamp field is scaled by 1000 (seconds -> milliseconds)
          override def extractTimestamp(element: SensorReading): Long = element.timestamp * 1000
        })

    val mainOutput = readings.process(new FreezingAlert())

    mainOutput.print()
    mainOutput.getSideOutput(freezingTag).print("alert data")

    env.execute()
  }
}

/** Sends readings below 30 degrees to the "freezing alert" side output; forwards all others. */
class FreezingAlert() extends ProcessFunction[SensorReading, SensorReading] {

  // Side-output tag; its id must match the tag used by consumers of getSideOutput.
  lazy val alertTag: OutputTag[String] = new OutputTag[String]("freezing alert")

  // NOTE(review): temperature units are not visible here; 30 presumably assumes Fahrenheit — confirm
  override def processElement(value: SensorReading,
                              ctx: ProcessFunction[SensorReading, SensorReading]#Context,
                              out: Collector[SensorReading]): Unit =
    if (value.temperature < 30) ctx.output(alertTag, "freezing alert for " + value.id)
    else out.collect(value)
}

 

233

posted on 2020-04-06 09:33  Lemo_wd  阅读(1627)  评论(0编辑  收藏  举报

导航