不积跬步，无以至千里；不积小流，无以成江海。——荀子

  博客园  :: 首页  :: 新随笔  :: 联系  :: 订阅  :: 管理

 

 

 

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import random

# https://www.jianshu.com/p/c31b08179655

def load_data(path="data.txt"):
    """Load a whitespace-separated two-column numeric text file.

    Each line is split on whitespace and its first two fields are parsed as
    floats; any further fields are ignored. Lines that are blank, have fewer
    than two fields, or whose first two fields are not numeric are skipped.

    Args:
        path: Path of the UTF-8 text file to read.

    Returns:
        A float ``numpy.ndarray`` of shape ``(n, 2)``, one row per parsed line.
        If no line parses, the result has shape ``(0, 2)`` so callers can still
        unpack ``samples_num, dim = data.shape``.

    Raises:
        OSError: If the file cannot be opened.
    """
    data = []
    # ``with`` guarantees the file is closed even if reading fails.
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            try:
                data.append([float(fields[0]), float(fields[1])])
            except (IndexError, ValueError):
                # Short/blank lines and non-numeric text are skipped on purpose;
                # other errors are no longer silently swallowed.
                continue
    return np.array(data, dtype=float).reshape(-1, 2)

def showCluster(dataset, k, centroids, cluster_assignment):
    """Plot every sample styled by its cluster, overlay the centroids, and show.

    Args:
        dataset: 2-D array; columns 0 and 1 are plotted as x and y.
        k: Number of centroids to draw (rows ``0..k-1`` of ``centroids``).
        centroids: Array whose rows hold centroid coordinates.
        cluster_assignment: Array whose column 0 holds each sample's cluster
            index (as returned by ``kmeans``).

    Note:
        Only 8 marker styles are defined, so ``k`` greater than 8 or a cluster
        index of 8 or more raises ``IndexError``.
    """
    n_rows, _ = dataset.shape
    point_styles = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr']
    centre_styles = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db']

    # One plot call per sample, styled by its assigned cluster index.
    for row in range(n_rows):
        style = point_styles[int(cluster_assignment[row, 0])]
        plt.plot(dataset[row, 0], dataset[row, 1], style)

    # Centroids are drawn with a larger marker so they stand out.
    for c in range(k):
        plt.plot(centroids[c, 0], centroids[c, 1], centre_styles[c], markersize=12)

    plt.show()


def cal_dist(v1, v2):
    """Return the Euclidean distance between two equally shaped arrays."""
    delta = v1 - v2
    squared = delta * delta
    return np.sqrt(np.sum(squared))

def select_init_centroids(dataset, k):
    """Pick ``k`` distinct rows of ``dataset`` at random as initial centroids.

    Args:
        dataset: 2-D array of shape ``(samples_num, dim)``.
        k: Number of centroids to select.

    Returns:
        Float array of shape ``(k, dim)`` whose rows are copies of ``k``
        distinct rows of ``dataset``, chosen with the ``random`` module.

    Raises:
        ValueError: If ``k`` is negative or exceeds the number of samples
            (the previous retry loop would spin forever in that case).
    """
    samples_num, dim = dataset.shape
    if not 0 <= k <= samples_num:
        raise ValueError(f"k must be between 0 and {samples_num}, got {k}")
    centroids = np.zeros((k, dim))
    # random.sample draws indices without replacement, so no duplicate check
    # or retry loop is needed.
    chosen = random.sample(range(samples_num), k)
    for row, idx in enumerate(chosen):
        centroids[row] = dataset[idx]
    return centroids

def kmeans(dataset, k, max_iter=None):
    """Cluster ``dataset`` into ``k`` groups with Lloyd's k-means algorithm.

    Initial centroids are ``k`` distinct samples from ``select_init_centroids``.
    Each pass assigns every sample to its nearest centroid and then moves each
    centroid to the mean of its members. Distances come from ``cal_dist``, and
    ties go to the lower centroid index. Iteration stops when no sample changes
    cluster, or after ``max_iter`` passes if given.

    Args:
        dataset: 2-D array of shape ``(samples_num, dim)``.
        k: Number of clusters.
        max_iter: Optional upper bound on the number of assign/update passes;
            ``None`` (default) iterates until assignments stop changing.

    Returns:
        ``(centroids, cluster_assignment)``.
        - ``centroids`` has shape ``(k, dim)``.
        - ``cluster_assignment`` has shape ``(samples_num, 2)``. Column 0 is the
          cluster index (stored as float). Column 1 is the distance to that
          centroid, as computed in the final assignment pass.
    """
    samples_num, dim = dataset.shape
    centroids = select_init_centroids(dataset, k)
    cluster_assignment = np.zeros((samples_num, 2))
    cluster_changed = True
    iterations = 0
    while cluster_changed and (max_iter is None or iterations < max_iter):
        cluster_changed = False
        iterations += 1
        for i in range(samples_num):
            # step 2: find the closest centroid. np.inf (instead of a finite
            # sentinel such as 1e6) ensures far-away samples are still assigned.
            min_dist = np.inf
            min_idx = 0
            for j in range(k):
                dist = cal_dist(dataset[i], centroids[j])
                if dist < min_dist:
                    min_dist = dist
                    min_idx = j

            # step 3: record the assignment. Refresh the distance on every
            # pass, not only when the label changes. Otherwise samples that
            # never leave cluster 0 keep the initial distance of 0, and
            # unchanged samples keep stale distances.
            if cluster_assignment[i, 0] != min_idx:
                cluster_changed = True
            cluster_assignment[i, 0] = min_idx
            cluster_assignment[i, 1] = min_dist

        # step 4: move each centroid to the mean of its members.
        for j in range(k):
            points_in_cluster = dataset[cluster_assignment[:, 0] == j]
            # An empty cluster would make np.mean return NaN and poison this
            # centroid for all later passes. Keep its previous position.
            if len(points_in_cluster) > 0:
                centroids[j, :] = np.mean(points_in_cluster, axis=0)

    return centroids, cluster_assignment

def main():
    """Load ``data.txt``, cluster it, print the centroids, and plot the result."""
    # A single constant keeps kmeans and showCluster in agreement about k.
    num_clusters = 4
    dataset = load_data()
    centroids, cluster_assignment = kmeans(dataset, num_clusters)
    print(centroids)
    showCluster(dataset, num_clusters, centroids, cluster_assignment)


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()

 

 

最终效果如下:

 附数据:

1.658985    4.285136  
-3.453687   3.424321  
4.838138    -1.151539  
-5.379713   -3.362104  
0.972564    2.924086  
-3.567919   1.531611  
0.450614    -3.302219  
-3.487105   -1.724432  
2.668759    1.594842  
-3.156485   3.191137  
3.165506    -3.999838  
-2.786837   -3.099354  
4.208187    2.984927  
-2.123337   2.943366  
0.704199    -0.479481  
-0.392370   -3.963704  
2.831667    1.574018  
-0.790153   3.343144  
2.943496    -3.357075  
-3.195883   -2.283926  
2.336445    2.875106  
-1.786345   2.554248  
2.190101    -1.906020  
-3.403367   -2.778288  
1.778124    3.880832  
-1.688346   2.230267  
2.592976    -2.054368  
-4.007257   -3.207066  
2.257734    3.387564  
-2.679011   0.785119  
0.939512    -4.023563  
-3.674424   -2.261084  
2.046259    2.735279  
-3.189470   1.780269  
4.372646    -0.822248  
-2.579316   -3.497576  
1.889034    5.190400  
-0.798747   2.185588  
2.836520    -2.658556  
-3.837877   -3.253815  
2.096701    3.886007  
-2.709034   2.923887  
3.367037    -3.184789  
-2.121479   -4.232586  
2.329546    3.179764  
-3.284816   3.273099  
3.091414    -3.815232  
-3.762093   -2.432191  
3.542056    2.778832  
-1.736822   4.241041  
2.127073    -2.983680  
-4.323818   -3.938116  
3.792121    5.135768  
-4.786473   3.358547  
2.624081    -3.260715  
-4.009299   -2.978115  
2.493525    1.963710  
-2.513661   2.642162  
1.864375    -3.176309  
-3.171184   -3.572452  
2.894220    2.489128  
-2.562539   2.884438  
3.491078    -3.947487  
-2.565729   -2.012114  
3.332948    3.983102  
-1.616805   3.573188  
2.280615    -2.559444  
-2.651229   -3.103198  
2.321395    3.154987  
-1.685703   2.939697  
3.031012    -3.620252  
-4.599622   -2.185829  
4.196223    1.126677  
-2.133863   3.093686  
4.668892    -2.562705  
-2.793241   -2.149706  
2.884105    3.043438  
-2.967647   2.848696  
4.479332    -1.764772  
-4.905566   -2.911070 

 

posted on 2020-05-14 11:16  hejunlin  阅读(718)  评论(0)  编辑  收藏  举报