k-means 图像分割

经典的无监督聚类算法,不多说,上代码。

 1 import numpy as np
 2 import pandas as pd
 3 import copy
 4 import matplotlib.pyplot as plt
 5 
 6 pic = plt.imread('cs-nonoise.jpg')
 7 # plt.imshow(pic)
 8 # pic.shape   #(1200, 800)
 9 data = pic.reshape(-1, 3)
10 
11 
12 def kmeans_wave(n, k, data):  # n为迭代次数, k为聚类数目, data为输入数据
13     data_new = copy.deepcopy(data)
14     data_new = np.column_stack((data_new, np.ones(1200*800)))   # 扩展一个维度用来存放标签
15     center_point = np.random.choice(1200*800, k, replace=False)     # 随机选择初始点
16     center = data_new[center_point,:]
17     distance = [[] for i in range(k)]   # 距离度量
18     for i in range(n):
19          for j in range(k):
20              distance[j] = np.sqrt(np.sum(np.square(data_new - np.array(center[j])), axis=1))     # 更新距离
21          data_new[:,3] = np.argmin(np.array(distance), axis=0)   # 将最小距离的类别标签作为当前数据的类别
22          for l in range(k):
23              center[l] = np.mean(data_new[data_new[:,3]==l], axis=0)   # 更新聚类中心
24 
25     return data_new
26 
27 
28 if __name__ == '__main__':
29     data_new = kmeans_wave(100,5,data)
30     print(data_new.shape)
31     # data_new = np.delete(data_new, 3, axis=1)
32     # print(data_new.shape)
33     pic_new = data_new[:,3].reshape(1200,800)   # 将多个标签展示出来
34     plt.imshow(pic_new)
35     plt.show()

结果:

        

                          原图                                                                                                                            k=5 结果图

 

posted @ 2017-10-22 20:49  三年一梦  阅读(4507)  评论(1编辑  收藏  举报