前言:机器学习所使用的数据大多具有很多的特征,想要直观的对数据的分布和特征之间的关系进行观测,需要将数据的主要特征提取出来,降低到三维及三维以下的空间来展示。

PCA(主成分分析)是常用的用于降维的方法,本文通过PCA对数据进行降维,再对降维后的数据用K-means算法聚类,以达到在低维空间可直观观察的目的。

(本文使用数字识别数据集,共1797个样本,每个样本是64维向量(8*8),分为10个数字)

在二维空间可视化的效果如图:黑色的为样本点,白色叉号是聚类中心。

 1 from time import time
 2 import numpy as np
 3 import pandas as pd
 4 import matplotlib as mpl
 5 import matplotlib.pyplot as plt
 6 import sklearn
 7 from sklearn import datasets
 8 from sklearn.datasets import load_digits
 9 from sklearn.preprocessing import scale
10 from sklearn import metrics
11 from sklearn.cluster import KMeans
12 from sklearn.decomposition import PCA
13 
14 np.random.seed(42)
15 
16 #1. 加载数据集
17 digits = load_digits()
18 
19 #2. 对数据集进行归一化缩放处理
20 data = scale(digits.data)
21 
22 #3. 数据基本信息:1797个样本,每个样本是64维向量(8*8),分为10个数字
23 n_samples, n_features = data.shape
24 n_digits = len(np.unique(digits.target))
25 labels = digits.target
26 
27 #4. 使用PCA提取2个主成份
28 reduced_data = PCA(n_components=2).fit_transform(data)
29 
30 #5. 使用K均值进行聚类,目标聚为10类
31 kmeans = KMeans(init='k-means++', n_clusters=n_digits, n_init=10)
32 
33 #6. 将PCA降维后的数据送入K均值模型
34 kmeans.fit(reduced_data)
35 
36 #7. 打印图片背景
37 h = .02 #越小越精细,越大越粗糙
38 # 计算边界用于分配颜色
39 x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
40 y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
41 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
42 # 使用K-means模型进行聚类
43 Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])
44 # 显示颜色底部图片
45 Z = Z.reshape(xx.shape)
46 plt.figure(1)
47 plt.clf()
48 plt.imshow(Z, interpolation='nearest',
49            extent=(xx.min(), xx.max(), yy.min(), yy.max()),
50            cmap=plt.cm.Paired,
51            aspect='auto', origin='lower')
52 
53 #8. 打印数据点
54 plt.plot(reduced_data[:, 0], reduced_data[:, 1], 'k.', markersize=2)
55 
56 #9. 打印聚类中心
57 centroids = kmeans.cluster_centers_#shape(10,2)
58 plt.scatter(centroids[:, 0], centroids[:, 1],
59             marker='x', s=169, linewidths=3,
60             color='w', zorder=10)
61 
62 #10. 显示图像
63 plt.xlim(x_min, x_max)
64 plt.ylim(y_min, y_max)
65 plt.xticks(())
66 plt.yticks(())
67 plt.show()

注:本文代码搬运自sk-learn官方示例:https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html#sphx-glr-auto-examples-cluster-plot-kmeans-digits-py,对代码中的注释进行了翻译。

posted on 2019-07-19 20:08  Sempron2800+  阅读(347)  评论(0编辑  收藏  举报