import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import spark.implicits._
/**
 * One input row for the feature pipeline in this snippet.
 *
 * Declared `final`: case classes should not be subclassed (that breaks the generated
 * `equals`/`copy` semantics), and nothing here needs to extend it.
 *
 * @param id       numeric row identifier
 * @param category categorical label that the pipeline string-indexes and one-hot encodes
 * @param age      numeric value that is copied as-is into the assembled "features" vector
 */
final case class Person(id: Long, category: String, age: Long)
// Sample input: six rows in which "a" occurs three times, "c" twice and "b" once.
// The DataFrame schema is derived from the Person case class (id, category, age).
val df = {
  val rows = Seq(
    Person(id = 0, category = "a", age = 10),
    Person(id = 1, category = "b", age = 5),
    Person(id = 2, category = "c", age = 4),
    Person(id = 3, category = "a", age = 11),
    Person(id = 4, category = "a", age = 20),
    Person(id = 5, category = "c", age = 1)
  )
  spark.createDataFrame(rows)
}
// Stage 1: StringIndexer maps each distinct "category" string to a numeric index.
// Its default ordering is by descending label frequency, which here yields
// a -> 0.0, c -> 1.0, b -> 2.0 (see the output table below).
val indexer = new StringIndexer().setOutputCol("categoryIndex").setInputCol("category")

// Stage 2: OneHotEncoder converts the categorical index into a binary sparse vector.
// dropLast is left at its default (true): with 3 categories the vectors have length 2,
// and the last index (2.0, i.e. "b") is encoded as the all-zero vector (2,[],[]).
val encoder = new OneHotEncoder().setInputCol(indexer.getOutputCol).setOutputCol("categoryClassVec")

// Stage 3: concatenate the one-hot vector with the raw "age" column into "features".
val assembler = {
  val assemblerInputs = Array(encoder.getOutputCol, "age")
  new VectorAssembler().setInputCols(assemblerInputs).setOutputCol("features")
}

// Chain the stages; kept on one line so it also works when pasted line-by-line into a REPL.
val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler))

// Fit the pipeline's estimator stages on df, then run every stage over the same DataFrame.
val featureDF = {
  val fittedModel = pipeline.fit(df)
  fittedModel.transform(df)
}
featureDF.show()
+---+--------+---+-------------+----------------+--------------+
| id|category|age|categoryIndex|categoryClassVec| features|
+---+--------+---+-------------+----------------+--------------+
| 0| a| 10| 0.0| (2,[0],[1.0])|[1.0,0.0,10.0]|
| 1| b| 5| 2.0| (2,[],[])| [0.0,0.0,5.0]|
| 2| c| 4| 1.0| (2,[1],[1.0])| [0.0,1.0,4.0]|
| 3| a| 11| 0.0| (2,[0],[1.0])|[1.0,0.0,11.0]|
| 4| a| 20| 0.0| (2,[0],[1.0])|[1.0,0.0,20.0]|
| 5| c| 1| 1.0| (2,[1],[1.0])| [0.0,1.0,1.0]|
+---+--------+---+-------------+----------------+--------------+
- python - How to interpret results of Spark OneHotEncoder - Stack Overflow