import os
import random
import sys

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; the clustering code only needs numpy
    plt = None

# Reference: https://www.jianshu.com/p/c31b08179655


def load_data(path="data.txt"):
    """Load 2-D sample points from a whitespace-separated text file.

    The first two whitespace-separated fields of each line are parsed as the
    point's coordinates; lines that are blank, have fewer than two fields, or
    are not numeric are skipped.

    Args:
        path: Path of the UTF-8 text file to read.

    Returns:
        A NumPy array of shape (n_points, 2), or an empty array if no line
        could be parsed.
    """
    data = []
    # `with` guarantees the file is closed even if reading raises.
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            try:
                data.append([float(fields[0]), float(fields[1])])
            except (IndexError, ValueError):
                # Blank/short line (IndexError) or non-numeric text
                # (ValueError): skip it, but let any other error surface.
                continue
    return np.array(data)


def showCluster(dataset, k, centroids, cluster_assignment):
    """Scatter-plot the samples coloured by cluster, then draw the centroids.

    Args:
        dataset: 2-D array of samples; only columns 0 and 1 are plotted.
        k: Number of clusters; at most 8 (the number of predefined markers).
        centroids: Array whose first ``k`` rows are the cluster centres.
        cluster_assignment: Array whose column 0 holds each sample's cluster
            index, as returned by ``kmeans``.

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``k`` exceeds the number of predefined markers.
    """
    if plt is None:
        raise ImportError("matplotlib is required for showCluster")
    sample_marks = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr']
    centroid_marks = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db']
    if k > len(sample_marks):
        raise ValueError(
            "showCluster supports at most %d clusters, got k=%d"
            % (len(sample_marks), k))

    labels = cluster_assignment[:, 0].astype(int)
    # One plot call per cluster instead of one per sample: same markers,
    # far fewer matplotlib artists.
    for label in np.unique(labels):
        members = dataset[labels == label]
        plt.plot(members[:, 0], members[:, 1], sample_marks[label])

    for i in range(k):
        plt.plot(centroids[i, 0], centroids[i, 1], centroid_marks[i],
                 markersize=12)
    plt.show()


def cal_dist(v1, v2):
    """Return the Euclidean distance between vectors ``v1`` and ``v2``."""
    diff = v1 - v2
    return np.sqrt(np.sum(diff * diff))


def select_init_centroids(dataset, k):
    """Pick ``k`` distinct samples at random as the initial centroids.

    Args:
        dataset: Array of shape (n_samples, dim).
        k: Number of centroids to choose.

    Returns:
        A float array of shape (k, dim) holding copies of the chosen rows.

    Raises:
        ValueError: If ``k`` exceeds the number of samples (there are not
            enough distinct rows to choose from).
    """
    samples_num, dim = dataset.shape
    if k > samples_num:
        raise ValueError(
            "cannot choose %d initial centroids from %d samples"
            % (k, samples_num))
    # random.sample draws without replacement, so no row is picked twice.
    chosen = random.sample(range(samples_num), k)
    # Fancy indexing copies, so later centroid updates never touch dataset.
    return dataset[chosen].astype(float).reshape(k, dim)


def kmeans(dataset, k):
    """Cluster ``dataset`` into ``k`` groups with Lloyd's k-means iteration.

    Args:
        dataset: Array of shape (n_samples, dim).
        k: Number of clusters (must be >= 1 and <= n_samples).

    Returns:
        ``(centroids, cluster_assignment)``: ``centroids`` has shape
        (k, dim); ``cluster_assignment`` has shape (n_samples, 2), with
        column 0 the sample's cluster index and column 1 its Euclidean
        distance to that cluster's final centroid.

    Raises:
        ValueError: If ``k < 1`` or ``k`` exceeds the number of samples.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    samples_num, dim = dataset.shape
    centroids = select_init_centroids(dataset, k)
    cluster_assignment = np.zeros((samples_num, 2))
    # Start from an impossible label so the first pass always counts as a
    # change; otherwise a first pass where every sample picks centroid 0
    # would stop before assignments are re-checked against real means.
    cluster_assignment[:, 0] = -1

    cluster_changed = True
    while cluster_changed:
        cluster_changed = False
        # Assignment step: attach every sample to its nearest centroid.
        for i in range(samples_num):
            dists = [cal_dist(dataset[i], centroids[j]) for j in range(k)]
            min_idx = int(np.argmin(dists))  # lowest index wins ties
            if cluster_assignment[i, 0] != min_idx:
                cluster_assignment[i, 0] = min_idx
                cluster_changed = True
            # Refresh the distance on every pass, not only on a label
            # change, so it matches the centroid the sample ends up with.
            cluster_assignment[i, 1] = dists[min_idx]

        # Update step: move each centroid to the mean of its members.
        for j in range(k):
            members = dataset[cluster_assignment[:, 0] == j]
            # An empty cluster keeps its previous centroid instead of
            # becoming NaN (np.mean of an empty array).
            if len(members):
                centroids[j, :] = np.mean(members, axis=0)
    return centroids, cluster_assignment
def main(path="data.txt", k=4):
    """Cluster the points in ``path`` into ``k`` groups, print and plot them.

    Args:
        path: Data file passed to ``load_data``.
        k: Number of clusters passed to both ``kmeans`` and ``showCluster``
            (one parameter, so the two can no longer disagree).
    """
    dataset = load_data(path)
    centroids, cluster_assignment = kmeans(dataset, k)
    print(centroids)
    showCluster(dataset, k, centroids, cluster_assignment)


# Run only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
最终效果如下（运行 main() 后由 showCluster 绘制的聚类结果图）：
附数据（保存为 data.txt，每行一个样本，两个以空格分隔的坐标值）：
1.658985 4.285136
-3.453687 3.424321
4.838138 -1.151539
-5.379713 -3.362104
0.972564 2.924086
-3.567919 1.531611
0.450614 -3.302219
-3.487105 -1.724432
2.668759 1.594842
-3.156485 3.191137
3.165506 -3.999838
-2.786837 -3.099354
4.208187 2.984927
-2.123337 2.943366
0.704199 -0.479481
-0.392370 -3.963704
2.831667 1.574018
-0.790153 3.343144
2.943496 -3.357075
-3.195883 -2.283926
2.336445 2.875106
-1.786345 2.554248
2.190101 -1.906020
-3.403367 -2.778288
1.778124 3.880832
-1.688346 2.230267
2.592976 -2.054368
-4.007257 -3.207066
2.257734 3.387564
-2.679011 0.785119
0.939512 -4.023563
-3.674424 -2.261084
2.046259 2.735279
-3.189470 1.780269
4.372646 -0.822248
-2.579316 -3.497576
1.889034 5.190400
-0.798747 2.185588
2.836520 -2.658556
-3.837877 -3.253815
2.096701 3.886007
-2.709034 2.923887
3.367037 -3.184789
-2.121479 -4.232586
2.329546 3.179764
-3.284816 3.273099
3.091414 -3.815232
-3.762093 -2.432191
3.542056 2.778832
-1.736822 4.241041
2.127073 -2.983680
-4.323818 -3.938116
3.792121 5.135768
-4.786473 3.358547
2.624081 -3.260715
-4.009299 -2.978115
2.493525 1.963710
-2.513661 2.642162
1.864375 -3.176309
-3.171184 -3.572452
2.894220 2.489128
-2.562539 2.884438
3.491078 -3.947487
-2.565729 -2.012114
3.332948 3.983102
-1.616805 3.573188
2.280615 -2.559444
-2.651229 -3.103198
2.321395 3.154987
-1.685703 2.939697
3.031012 -3.620252
-4.599622 -2.185829
4.196223 1.126677
-2.133863 3.093686
4.668892 -2.562705
-2.793241 -2.149706
2.884105 3.043438
-2.967647 2.848696
4.479332 -1.764772
-4.905566 -2.911070