scala实现kmeans算法
算法的概念不做过多解释,Google 一下就有一大把资料。这里直接贴上代码,代码中已有比较详细的注释。
主程序:
import scala.annotation.tailrec
import scala.io.Source
import scala.util.Random

/**
 * Local (single-JVM) k-means clustering of 2-D points read from a text file.
 *
 * @author vincent
 */
object LocalKMeans {
  /** Input file used when no path is passed as `args(0)`. */
  private val DefaultFileName = "/home/vincent/kmeans_data.txt"

  /** Number of clusters used when no k is passed as `args(1)`. */
  private val DefaultK = 3

  /** Convergence threshold handed to [[kmeans]]: the maximum allowed centroid movement. */
  private val Epsilon = 0.001

  /**
   * Reads the data set, picks k distinct random initial centroids, runs [[kmeans]] and prints
   * the intermediate and final centroids together with the elapsed time.
   *
   * @param args optional `args(0)`: input file path; optional `args(1)`: number of clusters k
   */
  def main(args: Array[String]): Unit = {
    val fileName = args.headOption.getOrElse(DefaultFileName)
    val knumbers = args.lift(1).map(_.toInt).getOrElse(DefaultK)
    val rand = new Random()

    // Read the text data.
    val points = readPoints(fileName)
    require(points.nonEmpty, s"no points found in $fileName")

    // Randomly choose k *distinct* points as the initial centroids. Drawing indices with
    // replacement could pick the same point twice; equal centroids share a single cluster
    // key in `kmeans`, which would silently yield fewer than k clusters.
    val candidates = points.distinct
    require(knumbers >= 1 && knumbers <= candidates.length,
      s"k must be between 1 and the number of distinct points (${candidates.length}), got $knumbers")
    val centroids = rand.shuffle(candidates).take(knumbers)

    val startTime = System.currentTimeMillis()
    println("initialize centroids:\n" + centroids.mkString("\n") + "\n")
    println("test points: \n" + points.mkString("\n") + "\n")

    val resultCentroids = kmeans(points, centroids, Epsilon)

    val runTime = System.currentTimeMillis() - startTime
    println("run Time: " + runTime + "\nFinal centroids: \n" + resultCentroids.mkString("\n"))
  }

  /**
   * Parses the data file. Each non-blank line holds the x and y coordinates of one point,
   * separated by whitespace (tabs or spaces); extra columns are parsed but ignored.
   * The file is closed once it has been fully read, even if parsing fails.
   */
  private def readPoints(fileName: String): Vector[Point] = {
    val source = Source.fromFile(fileName)
    try {
      source.getLines()
        .map(_.trim)
        .filter(_.nonEmpty)
        .map { line =>
          val parts = line.split("\\s+").map(_.toDouble)
          require(parts.length >= 2, s"expected two coordinates, got: $line")
          Point(parts(0), parts(1))
        }
        .toVector // materialize before closing: getLines is lazy
    } finally {
      source.close()
    }
  }

  /**
   * Core of the algorithm: repeats assignment/update rounds until no centroid moves
   * by more than `epsilon`.
   *
   * Each round groups the points by their closest centroid, then moves every centroid to the
   * mean of its group; a centroid with no assigned points stays where it is.
   *
   * @param points    the data set
   * @param centroids the current centroids; expected to be distinct, since equal centroids
   *                  share one cluster
   * @param epsilon   the minimum movement that triggers another round
   * @return the final centroids, in the same order as `centroids`
   */
  @tailrec
  def kmeans(points: Seq[Point], centroids: Seq[Point], epsilon: Double): Seq[Point] = {
    // Split the data set into clusters keyed by the nearest centroid.
    val clusters = points.groupBy(closestCentroid(centroids, _))
    println("clusters: \n" + clusters.mkString("\n") + "\n")

    // New centroid of each cluster = mean of its points; an empty cluster keeps its old centroid.
    val newCentroids = centroids.map { oldCentroid =>
      clusters.get(oldCentroid).fold(oldCentroid) { pointsInCluster =>
        pointsInCluster.reduceLeft(_ + _) / pointsInCluster.length
      }
    }

    // How far each centroid moved relative to its previous position.
    val movement = centroids.zip(newCentroids).map { case (a, b) => a distance b }
    val movementText = movement
      .map(d => "%.3f".formatLocal(java.util.Locale.ROOT, d))
      .mkString("(", ", ", ")")
    println("Centroids changed by\n" + movementText + "\nto\n" + newCentroids.mkString(", ") + "\n")

    // Iterate again while any centroid moved by more than epsilon.
    if (movement.exists(_ > epsilon)) kmeans(points, newCentroids, epsilon)
    else newCentroids
  }

  /**
   * Returns the centroid nearest to `point` (Euclidean distance). On a tie the centroid that
   * appears later in `centroids` wins. Throws if `centroids` is empty.
   */
  def closestCentroid(centroids: Seq[Point], point: Point): Point =
    centroids.reduceLeft((a, b) => if ((point distance a) < (point distance b)) a else b)
}
自定义Point类:
/**
 * Factory helpers for [[Point]].
 *
 * @author vincent
 */
object Point {
  /** Returns a point whose coordinates are each drawn uniformly from [0, 50). */
  def random(): Point = Point(scala.util.Random.nextDouble() * 50, scala.util.Random.nextDouble() * 50)
}

/**
 * An immutable point (or 2-D vector) in the plane.
 *
 * @param x horizontal coordinate
 * @param y vertical coordinate
 */
case class Point(x: Double, y: Double) {
  /** Component-wise sum. */
  def +(that: Point): Point = Point(x + that.x, y + that.y)

  /** Component-wise difference. */
  def -(that: Point): Point = Point(x - that.x, y - that.y)

  /** Divides both coordinates by `d` (e.g. to turn a sum of points into their mean). */
  def /(d: Double): Point = Point(x / d, y / d)

  /** Euclidean norm of this point seen as a vector from the origin. */
  def pointLength: Double = math.hypot(x, y)

  /** Euclidean distance between this point and `that`. */
  def distance(that: Point): Double = (this - that).pointLength

  // Fixed locale: with a ',' decimal separator the default locale would print
  // e.g. "(1,000, 2,000)", which is ambiguous with the coordinate separator.
  override def toString: String = "(%.3f, %.3f)".formatLocal(java.util.Locale.ROOT, x, y)
}
测试数据集(每行一个点,依次为 x、y 坐标,两列之间以制表符分隔):
12.044996 36.412378
31.881257 33.677009
41.703139 46.170517
43.244406 6.991669
19.319000 27.926669
3.556824 40.935215
29.328655 33.303675
43.702858 22.305344
28.978940 28.905725
10.426760 40.311507