k-means 图像分割
经典的无监督聚类算法,不多说,上代码。
1 import numpy as np 2 import pandas as pd 3 import copy 4 import matplotlib.pyplot as plt 5 6 pic = plt.imread('cs-nonoise.jpg') 7 # plt.imshow(pic) 8 # pic.shape #(1200, 800) 9 data = pic.reshape(-1, 3) 10 11 12 def kmeans_wave(n, k, data): # n为迭代次数, k为聚类数目, data为输入数据 13 data_new = copy.deepcopy(data) 14 data_new = np.column_stack((data_new, np.ones(1200*800))) # 扩展一个维度用来存放标签 15 center_point = np.random.choice(1200*800, k, replace=False) # 随机选择初始点 16 center = data_new[center_point,:] 17 distance = [[] for i in range(k)] # 距离度量 18 for i in range(n): 19 for j in range(k): 20 distance[j] = np.sqrt(np.sum(np.square(data_new - np.array(center[j])), axis=1)) # 更新距离 21 data_new[:,3] = np.argmin(np.array(distance), axis=0) # 将最小距离的类别标签作为当前数据的类别 22 for l in range(k): 23 center[l] = np.mean(data_new[data_new[:,3]==l], axis=0) # 更新聚类中心 24 25 return data_new 26 27 28 if __name__ == '__main__': 29 data_new = kmeans_wave(100,5,data) 30 print(data_new.shape) 31 # data_new = np.delete(data_new, 3, axis=1) 32 # print(data_new.shape) 33 pic_new = data_new[:,3].reshape(1200,800) # 将多个标签展示出来 34 plt.imshow(pic_new) 35 plt.show()
结果:
原图 k=5 结果图