Machine Learnign 24 -- 聚类分析
聚类(Cluster)分析又称群分析,是研究分类(样品或指标)问题的一种统计分析方法,也是数据挖掘的一个重要算法。
聚类分析是由若干模式(Pattern)组成的,通常,模式是一个度量(Measuremern)的向量,或者是多维空间中的一个点。聚类分析以相似性为基础,在同一个聚类中的模式之间比不在同一个聚类中的模式之间具有更多的相似性。
许多聚类算法在小于200个数据对象的小数据集合上工作得很好;但是,一个大规模数据库可能包含几百万个对象,在这样的大数据集合样本上进行聚类可能会导致有偏差的结果,因此需要具有高度可伸缩性的聚类算法。
许多聚类算法擅长处理低维的数据,可能只涉及两维或三维。人类的眼睛在最多三维的情况下能够很好的判断聚类的质量。在高维空间中,聚类数据对象是非常有挑战性的,特别是考虑到这样的数据可能分布非常稀疏,而且高度偏斜。因此,对于多维的数据进行聚类算法时首先要进行降维处理。
下面通过一个例子来展示聚类算法的应用。在这里使用UCI数据仓库中的wine数据集(http://archive/ics.uci.edu/ml/datasets/Wine),这个数据集中包含13个数据特征,并且数据被分为三个类别,通过KMean算法自动聚类。
1 #聚类分析 2 from pandas import read_csv 3 from sklearn.cluster import KMeans 4 from sklearn.decomposition import PCA 5 from sklearn.preprocessing import scale 6 from sklearn.preprocessing import StandardScaler 7 from matplotlib import pyplot as plt 8 from mpl_toolkits.mplot3d import Axes3D 9 import numpy as np 10 from sklearn import metrics 11 12 #导入数据 13 fileanem='/home/aistudio/work/wine.data.csv' 14 names={'class','Alcohol','MalicAcid','Ash','AlclinityOfAsh','Magnesium','TotalPhenols','Flavanoids','NonflayanoidPhenols','Proathocyanins','ColorIntensityt','Hue','OD280/OD315','Proline'} 15 dataset=read_csv(fileanem,names=names) 16 dataset['class']=dataset['class'].replace(to_replace=[1,2,3],value=[0,1,2]) 17 array=dataset.values 18 x=array[:,1:13] 19 y=array[:,0] 20 21 #数据降维 22 pca=PCA(n_components=3) 23 x_scale=StandardScaler().fit_transform(x) 24 x_reduce=pca.fit_transform(scale(x_scale)) 25 26 #模型训练 27 model=KMeans(n_clusters=3) 28 model.fit(x_reduce) 29 labels=model.labels_ 30 #print(labels) 31 32 #输出模型的准确度 33 print('%.3f %.3f %.3f %.3f %.3f %.3f' % (metrics.homogeneity_score(y,labels), 34 metrics.completeness_score(y,labels), 35 metrics.v_measure_score(y,labels), 36 metrics.adjusted_rand_score(y,labels), 37 metrics.adjusted_mutual_info_score(y,labels), 38 metrics.silhouette_score(x_reduce,labels))) 39 40 #绘制模型的分布图 41 fig=plt.figure() 42 ax=Axes3D(fig,rect=[0,0,.95,1],elev=48,azim=134) 43 ax.scatter(x_reduce[:,0],x_reduce[:,1],x_reduce[:,2],c=labels.astype(np.float)) 44 plt.show()
测试结果如下:
0.735 0.731 0.733 0.740 0.728 0.413