k-mean鸢尾花分类
import numpy as np # 矩阵计算函数库 import matplotlib.pyplot as plt # 可视化图像 from mpl_toolkits.mplot3d import Axes3D # 3维图 from sklearn.cluster import KMeans # KMeans聚类算法 from sklearn import datasets # 鸢尾花数据集 np.random.seed(5) # 设置随机种子,5个数,用于K-means聚类算法的初始化 centers = [[1, 1], [-1, -1], [1, -1]] # 聚类中心 iris = datasets.load_iris() # 获取数据集 X = iris.data # 训练所需的数据集 y = iris.target # 数据集对应的分类标签,属于监督学习 estimators = {'k_means_iris_3': KMeans(n_clusters=3), 'k_means_iris_8': KMeans(n_clusters=8), 'k_means_iris_bad_init': KMeans(n_clusters=3, n_init=1, init='random')} # 设置K-means的参数,n_clusters是需要计算出的集群数,n_init使用不同centroid seeds运行K-means的时间,init是初始化方法 fignum = 1 for name, est in estimators.items(): fig = plt.figure(fignum, figsize=(4, 3)) # figsize指定图像的纵向高度和横向宽度 plt.clf() # 清空当前图像操作,此处可以不加 ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134) # 返回3D图形对象 plt.cla() # 清空当前坐标操作,此处可以不加 est.fit(X) # 用数据对算法进行拟合操作 labels = est.labels_ # 得到每一数据点的分类结果 # 绘制散点图 ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=labels.astype(np.float)) # scatter是绘制散点图的函数,前面3个参数对应数据在x,y,z轴的坐标,c代表色彩颜色序列 # 设置x,y,z轴的刻度标签,[]代表不描绘刻度 ax.w_xaxis.set_ticklabels([]) ax.w_yaxis.set_ticklabels([]) ax.w_zaxis.set_ticklabels([]) # 设置x,y,z轴的标签 ax.set_xlabel('Petal width') ax.set_ylabel('Sepal length') ax.set_zlabel('Petal length') fignum = fignum + 1 # Plot the ground truth fig = plt.figure(fignum, figsize=(4, 3)) plt.clf() ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134) plt.cla() for name, label in [('Setosa', 0), ('Versicolour', 1), ('Virginica', 2)]: # 在数据集中心绘制 分类标签的名字 ax.text3D(X[y == label, 3].mean(), X[y == label, 0].mean() + 1.5, X[y == label, 2].mean(), name, horizontalalignment='center', # center代表text向中间水平对齐 bbox=dict(alpha=.5, edgecolor='w', facecolor='w')) # bbox用于设置ext背景框 alpha为透明度,edgecolor为边框颜色(w为white之意),facecolor为背景框内部颜色 # Reorder the labels to have colors matching the cluster results y = np.choose(y, [1, 2, 0]).astype(np.float) ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y) # 绘制散点图 # 设置x,y,z轴的刻度标签,[]代表不描绘刻度 ax.w_xaxis.set_ticklabels([]) ax.w_yaxis.set_ticklabels([]) ax.w_zaxis.set_ticklabels([]) # 设置x,y,z轴的标签 ax.set_xlabel('Petal width') ax.set_ylabel('Sepal length') ax.set_zlabel('Petal length') plt.show() # 显示图像