Implementing the k-means algorithm in Scala

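I won't go into the theory of the algorithm here; a quick Google search turns up plenty of explanations. Below is the code, which is fairly thoroughly commented.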
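Since the post skips the theory, here is a compact statement of what the code computes, assuming the usual Euclidean k-means (Lloyd's iteration). With data points $x_i$, current centroids $\mu_1, \dots, \mu_k$, clusters $C_j$ and the threshold $\epsilon$ (the `epsilon` argument of `kmeans`), each call of `kmeans` performs:

```latex
\begin{aligned}
&\text{assignment:} && C_j = \{\, x_i : j = \arg\min_{l} \lVert x_i - \mu_l \rVert \,\} \\
&\text{update:}     && \mu_j' = \frac{1}{\lvert C_j \rvert} \sum_{x \in C_j} x
                       \qquad (\mu_j' = \mu_j \ \text{if}\ C_j = \emptyset) \\
&\text{stop when:}  && \max_j \lVert \mu_j' - \mu_j \rVert \le \epsilon
\end{aligned}
```

The three lines correspond to `points.groupBy(closestCentroid(centroids, _))`, to `pointsInCluster.reduceLeft(_ + _) / pointsInCluster.length` (with `case None => oldCentroid` for an empty cluster), and to the `movement.exists(_ > epsilon)` test.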

Main program:

import scala.io.Source
import scala.util.Random

/**
 * @author vincent
 *
 */
object LocalKMeans {
    def main(args: Array[String]): Unit = {
        val fileName = "/home/vincent/kmeans_data.txt"
        val knumbers = 3
        val rand = new Random()

        //  Read the text data: one point per line, x and y separated by whitespace
        //  (tabs or spaces); the file is closed once all lines are read
        val source = Source.fromFile(fileName)
        val lines = try source.getLines().toVector finally source.close()
        val points = lines.filter(_.trim.nonEmpty).map(line => {
            val parts = line.trim.split("\\s+").map(_.toDouble)
            Point(parts(0), parts(1))
        })

        //  Randomly initialize k centroids: pick k of the data points without repeats
        val centroids = rand.shuffle(points.toList).take(knumbers)
        val startTime = System.currentTimeMillis()
        println("initialize centroids:\n" + centroids.mkString("\n") + "\n")
        println("test points: \n" + points.mkString("\n") + "\n")

        val resultCentroids = kmeans(points, centroids, 0.001)

        val endTime = System.currentTimeMillis()
        val runTime = endTime - startTime
        println("run Time: " + runTime + " ms\nFinal centroids: \n" + resultCentroids.mkString("\n"))
    }

    //  Core function of the algorithm
    def kmeans(points: Seq[Point], centroids: Seq[Point], epsilon: Double): Seq[Point] = {
        //  Split the data set into clusters, keyed by each point's nearest centroid
        val clusters = points.groupBy(closestCentroid(centroids, _))
        println("clusters: \n" + clusters.mkString("\n") + "\n")
        //  Average the points in each cluster to get that cluster's new centroid;
        //  a centroid that attracted no points stays where it is
        val newCentroids = centroids.map(oldCentroid => {
            clusters.get(oldCentroid) match {
                case Some(pointsInCluster) => pointsInCluster.reduceLeft(_ + _) / pointsInCluster.length
                case None => oldCentroid
            }
        })
        //  How far each new centroid moved from the corresponding old one
        val movement = (centroids zip newCentroids).map({ case (a, b) => a distance b })
        println("Centroids changed by\n" + movement.map(d => "%.3f".format(d)).mkString("(", ", ", ")")
            + "\nto\n" + newCentroids.mkString(", ") + "\n")
        //  Iterate again while any centroid moved more than epsilon (the convergence threshold)
        if (movement.exists(_ > epsilon))
            kmeans(points, newCentroids, epsilon)
        else
            newCentroids
    }

    //  Return the centroid nearest to the given point
    def closestCentroid(centroids: Seq[Point], point: Point): Point = {
        centroids.reduceLeft((a, b) => if ((point distance a) < (point distance b)) a else b)
    }
}
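The recursive `kmeans` above prints its progress and keeps recursing until every centroid moves by at most `epsilon`. If you want a quiet version that is also guaranteed to stop, one option is a `@tailrec` loop with an iteration cap. This is a sketch under my own assumptions, not the author's code: the names `kmeansQuiet` and `maxIter` are hypothetical, and `closestCentroid` here uses `minBy`, which behaves like the `reduceLeft` comparison above except that an exact tie goes to the earlier centroid. A minimal `Point` is redefined so the snippet compiles on its own.

```scala
import scala.annotation.tailrec

// Minimal stand-in for the article's Point class so this snippet compiles on its own
final case class Point(x: Double, y: Double) {
  def +(that: Point): Point = Point(x + that.x, y + that.y)
  def /(d: Double): Point = Point(x / d, y / d)
  def distance(that: Point): Double = math.hypot(x - that.x, y - that.y)
}

object KMeansSketch {
  // Nearest centroid by Euclidean distance; on an exact tie minBy keeps the earlier centroid
  def closestCentroid(centroids: Seq[Point], p: Point): Point =
    centroids.minBy(c => p.distance(c))

  // Hypothetical quiet variant: same assignment, update and convergence test as the
  // article's kmeans, but no printing and at most `maxIter` iterations (an assumption)
  @tailrec
  def kmeansQuiet(points: Seq[Point], centroids: Seq[Point],
                  epsilon: Double, maxIter: Int = 100): Seq[Point] = {
    val clusters = points.groupBy(closestCentroid(centroids, _))
    val newCentroids = centroids.map { old =>
      clusters.get(old) match {
        case Some(ps) => ps.reduceLeft(_ + _) / ps.length
        case None     => old // an empty cluster keeps its old centroid
      }
    }
    val moved = centroids.zip(newCentroids).exists { case (a, b) => a.distance(b) > epsilon }
    if (moved && maxIter > 1) kmeansQuiet(points, newCentroids, epsilon, maxIter - 1)
    else newCentroids
  }
}
```

For example, with points `(0, 0)`, `(0, 1)`, `(10, 10)`, `(10, 11)` and starting centroids `(0, 0)` and `(10, 10)`, the loop stops after two iterations with centroids `Point(0.0,0.5)` and `Point(10.0,10.5)`. The cap is only a safety net: it guarantees termination at the cost of possibly returning centroids that have not fully converged.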

Custom Point class:

/**
 * @author vincent
 *
 */
object Point {
    //  A random point with both coordinates in [0, 50)
    def random() = {
        new Point(math.random() * 50, math.random() * 50)
    }
}

case class Point(x: Double, y: Double) {
    def +(that: Point) = new Point(this.x + that.x, this.y + that.y)
    def -(that: Point) = new Point(this.x - that.x, this.y - that.y)
    def /(d: Double) = new Point(this.x / d, this.y / d)
    //  Euclidean norm of the point, viewed as a vector from the origin
    def pointLength = math.sqrt(x * x + y * y)
    //  Euclidean distance between two points
    def distance(that: Point) = (this - that).pointLength
    override def toString = "(%.3f, %.3f)".format(x, y)
}
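A quick check of what the `Point` operators compute: `pointLength` is the Euclidean norm, `distance` is the Euclidean distance, and `reduceLeft(_ + _) / n`, as used in `kmeans`, is the component-wise mean. The snippet repeats the class in compressed form (without the custom `toString`) so it runs on its own; the helper `mean` and the object `PointDemo` are hypothetical, added only for illustration.

```scala
final case class Point(x: Double, y: Double) {
  def +(that: Point): Point = Point(x + that.x, y + that.y)
  def -(that: Point): Point = Point(x - that.x, y - that.y)
  def /(d: Double): Point = Point(x / d, y / d)
  def pointLength: Double = math.sqrt(x * x + y * y)        // Euclidean norm
  def distance(that: Point): Double = (this - that).pointLength
}

object PointDemo {
  // Hypothetical helper: the component-wise mean, computed the same way kmeans does it
  def mean(ps: Seq[Point]): Point = ps.reduceLeft(_ + _) / ps.length

  def main(args: Array[String]): Unit = {
    println(Point(3, 4).pointLength)                           // 5.0
    println(Point(1, 1).distance(Point(4, 5)))                 // 5.0
    println(mean(Seq(Point(0, 0), Point(2, 4), Point(4, 2))))  // Point(2.0,2.0)
  }
}
```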

Test data set:

12.044996    36.412378
31.881257    33.677009
41.703139    46.170517
43.244406    6.991669
19.319000    27.926669
3.556824    40.935215
29.328655    33.303675
43.702858    22.305344
28.978940    28.905725
10.426760    40.311507
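The listing above separates the two columns with runs of spaces, although a data file like this is often tab-separated. Splitting each line on any run of whitespace (`\\s+`) accepts either form. A minimal parsing sketch, assuming one `x y` pair per line; the object name `ParseData` and the inline copy of the data are illustrative only.

```scala
final case class Point(x: Double, y: Double)

object ParseData {
  // Turn lines of the form "x<whitespace>y" into points; blank lines are skipped.
  // Splitting on "\\s+" accepts tabs as well as runs of spaces.
  def parse(lines: Seq[String]): Vector[Point] =
    lines.map(_.trim).filter(_.nonEmpty).map { line =>
      val parts = line.split("\\s+").map(_.toDouble)
      Point(parts(0), parts(1))
    }.toVector

  // The article's test data set, copied as displayed (space-separated)
  val sample: Seq[String] = Seq(
    "12.044996    36.412378",
    "31.881257    33.677009",
    "41.703139    46.170517",
    "43.244406    6.991669",
    "19.319000    27.926669",
    "3.556824    40.935215",
    "29.328655    33.303675",
    "43.702858    22.305344",
    "28.978940    28.905725",
    "10.426760    40.311507"
  )

  def main(args: Array[String]): Unit =
    println(parse(sample).length) // 10
}
```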

posted @ 2013-09-02 17:18  vincent_hv