8.5 聚类算法
- KMeans 是一个迭代求解的聚类算法。
- 其属于划分(Partitioning)型的聚类方法,即首先创建K个划分,然后迭代地将样本从一个划分转移到另一个划分来改善最终聚类的质量。
ML包下的KMeans方法位于org.apache.spark.ml.clustering包下,其过程大致如下:
- 根据给定的k值,选取k个样本点作为初始划分中心
- 计算所有样本点到每一个划分中心的距离,并将所有样本点划分到距离最近的划分中心
- 计算每个划分中样本点的平均值,将其作为新的中心;循环进行2~3步直至达到最大迭代次数,或划分中心的变化小于某一预定义阈值
数据集:使用UCI数据集中的鸢尾花数据Iris进行实验,它可以在iris获取,Iris数据的样本容量为150,有四个实数值的特征,分别代表花朵四个部位的尺寸,以及该样本对应鸢尾花的亚种类型(共有3种亚种类型)
1 2 3 4 5 6 | 5.1 , 3.5 , 1.4 , 0.2 ,setosa ... 5.4 , 3.0 , 4.5 , 1.5 ,versicolor ... 7.1 , 3.0 , 5.9 , 2.1 ,virginica ... |
在使用前,引入需要的包:
1 2 | import org.apache.spark.ml.clustering.{KMeans,KMeansModel} import org.apache.spark.ml.linalg.Vectors |
开启RDD的隐式转换:
1 | import spark.implicits._ |
为了便于生成相应的DataFrame,这里定义一个名为model_instance的case class作为DataFrame每一行(一个数据样本)的数据类型
1 2 | scala> case class model_instance (features: Vector) defined class model_instance |
在定义数据类型完成后,即可将数据读入RDD[model_instance]的结构中,并通过RDD的隐式转换.toDF()方法完成RDD到DataFrame的转换:
1 2 3 4 5 6 7 | scala> val rawData = sc.textFile( "file:///usr/local/spark/iris.txt" ) rawData: org.apache.spark.rdd.RDD[String] = iris.csv MapPartitionsRDD[ 48 ] at textFile at <console>: 33 scala> val df = rawData. map (line = > | { model_instance( Vectors.dense(line.split( "," ). filter (p = > p.matches( "\\d*(\\.?)\\d*" )) | . map (_.toDouble)) )}).toDF() df: org.apache.spark.sql.DataFrame = [features: vector] |
在得到数据后,我们即可通过ML包的固有流程:创建Estimator并调用其fit()方法来生成相应的Transformer对象,很显然,在这里KMeans类是Estimator,而用于保存训练后模型的KMeansModel类则属于Transformer。
1 2 3 4 5 6 | scala> val kmeansmodel = new KMeans(). | setK( 3 ). | setFeaturesCol( "features" ). | setPredictionCol( "prediction" ). | fit(df) kmeansmodel: org.apache.spark.ml.clustering.KMeansModel = kmeans_d8c043c3c339 |
与MLlib中的实现不同,KMeansModel作为一个Transformer,不再提供predict()样式的方法,而是提供了一致性的transform()方法,用于将存储在DataFrame中的给定数据集进行整体处理,生成带有预测簇标签的数据集
1 2 | scala> val results = kmeansmodel.transform(df) results: org.apache.spark.sql.DataFrame = [features: vector, prediction: int ] |
为了方便观察,我们可以使用collect()方法,该方法将DataFrame中所有的数据组织成一个Array对象进行返回:
1 2 3 4 5 6 7 8 9 10 | scala> results.collect().foreach( | row = > { | println( row( 0 ) + " is predicted as cluster " + row( 1 )) | }) [ 5.1 , 3.5 , 1.4 , 0.2 ] is predicted as cluster 2 ... [ 6.3 , 3.3 , 6.0 , 2.5 ] is predicted as cluster 1 ... [ 5.8 , 2.7 , 5.1 , 1.9 ] is predicted as cluster 0 ... |
也可以通过KMeansModel类自带的clusterCenters属性获取到模型的所有聚类中心情况:
1 2 3 4 5 6 7 | scala> kmeansmodel.clusterCenters.foreach( | center = > { | println( "Clustering Center:" + center) | }) Clustering Center:[ 5.883606557377049 , 2.740983606557377 , 4.388524590163936 , 1.4344262295081964 ] Clustering Center:[ 6.8538461538461535 , 3.076923076923076 , 5.715384615384614 , 2.053846153846153 ] Clustering Center:[ 5.005999999999999 , 3.4180000000000006 , 1.4640000000000002 , 0.2439999999999999 ] |
与MLlib下的实现相同,KMeansModel类也提供了计算 集合内误差平方和(Within Set Sum of Squared Error, WSSSE) 的方法来度量聚类的有效性,在真实K值未知的情况下,该值的变化可以作为选取合适K值的一个重要参考:
1 2 | scala> kmeansmodel.computeCost(df) res15: Double = 78.94084142614622 |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 从 HTTP 原因短语缺失研究 HTTP/2 和 HTTP/3 的设计差异
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 使用C#创建一个MCP客户端
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 按钮权限的设计及实现