spark ml中一个比较通用的transformer
spark ml中有许多好用的transformer,很方便用来做特征的处理,比如Tokenizer, StopWordsRemover等,具体可参看文档:http://spark.apache.org/docs/2.1.0/ml-features.html . 但是呢,这些都是一些特定的操作,组内的同事提了一个需求,能不能写一个通用的模板,用来做特征转化,让代码看起来比较整洁规整。后来经过参考spark中那些transformer的写法, 弄了一个比较通用的模板,只能说比较通用,还是有些需求不能满足的。模板其实很简单,几行代码就搞定了,如下:
import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable} import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.types._ // 对泛型参数进行说明: I表示 InputCol的类型, O 表示OutputCol的类型 class KgTransformer[I, O](override val uid: String) extends UnaryTransformer[I, O, KgTransformer[I, O]] with DefaultParamsWritable { var kgOutputDataType: DataType = _ var f: I => O = _ // 第一个参数表示OutputCol的类型, 这里需要使用spark ml中类型,其它更多的使用方式可以参考官方文档或者其它源码中写的相似代码 // 第二个参数表示 作用在 InputCol上的操作函数 def this(kgOutputDataType: DataType, f: I => O) = { this(Identifiable.randomUID("kgTransformer")) this.kgOutputDataType = kgOutputDataType this.f = f } override def outputDataType: DataType = kgOutputDataType //自定义实现验证, 这里默认没验证 override def validateInputType(inputType: DataType): Unit = {} // 这里是申明一个名为createTransformFunc的函数, 返回值是一个函数:返回一个参数类型为I, 返回值类型为 O 的函数 override def createTransformFunc: (I) => O = { //其实这玩意就是想搞一个map的函数体来 (in: I) => f(in) //返回一个函数 } }
以上这个工具类模型就写好了,以后对于一些操作都可以用这种通用的写法,如下:
// 使用方式举例: object TestKgTransformer { def main(args: Array[String]) { val spark = SparkSession.builder().appName("TestKgTransformer").master("local").getOrCreate() spark.sparkContext.setLogLevel("ERROR") import spark.implicits._ val data: DataFrame = spark.createDataFrame(Seq( (0.0, "a;c;e"), (1.0, "b;f;g"), (2.0, "c;c"), (3.0, "c;k;f;e;c") )).toDF("id", "words") data.show(false) // 输入列为id, 输出列为kgIdOut, 操作为给 id列的值加上5 val kgTrans1 = new KgTransformer[Double, Double](DataTypes.DoubleType, _ + 5).setInputCol("id").setOutputCol("kgIdOut") val res1 = kgTrans1.transform(data) res1.show(false) // 输入列为words, 输出列为 kgWordsOut ,操作为对 words 进行切割 val kgTrans2 = new KgTransformer[String, Array[String]](new ArrayType(StringType, true), _.split("\\;")).setInputCol("words").setOutputCol("kgWordsOut") kgTrans2.transform(res1).show(false) spark.stop() } }
就只有这么多啦,代码量很少,不过挺实用的,上面的代码是可以直接运行的。
结果如下:
+---+---------+
|id |words |
+---+---------+
|0.0|a;c;e |
|1.0|b;f;g |
|2.0|c;c |
|3.0|c;k;f;e;c|
+---+---------+
+---+---------+-------+
|id |words |kgIdOut|
+---+---------+-------+
|0.0|a;c;e |5.0 |
|1.0|b;f;g |6.0 |
|2.0|c;c |7.0 |
|3.0|c;k;f;e;c|8.0 |
+---+---------+-------+
+---+---------+-------+---------------+
|id |words |kgIdOut|kgWordsOut |
+---+---------+-------+---------------+
|0.0|a;c;e |5.0 |[a, c, e] |
|1.0|b;f;g |6.0 |[b, f, g] |
|2.0|c;c |7.0 |[c, c] |
|3.0|c;k;f;e;c|8.0 |[c, k, f, e, c]|
+---+---------+-------+---------------+