# 鸢尾花分类 (Iris flower classification)

from sklearn import datasets

import matplotlib.pyplot as plt

import numpy as np

from sklearn import tree

# Iris数据集是常用的分类实验数据集,

# 由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,

# 是一类多重变量分析的数据集。数据集包含150个样本,

# 分为3类,每类50个样本,每个样本包含4个属性。

# 可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

 

# --- Part 1: decision-tree classification ------------------------------------

# Load the Iris dataset.
iris = datasets.load_iris()

iris_data = iris['data']                 # feature matrix, one row per flower
iris_label = iris['target']              # integer class index per flower
iris_target_name = iris['target_names']  # class index -> species name

# load_iris already returns NumPy arrays; asarray avoids a needless copy.
X = np.asarray(iris_data)
Y = np.asarray(iris_label)

print(X)

# Train a shallow decision tree.
# random_state is fixed because scikit-learn randomly permutes the features at
# each split, and that permutation decides ties between equally good splits.
# On Iris, petal length and petal width both isolate Setosa perfectly at the
# root, so an unseeded tree may split on either one. Predictions for inputs
# where those two features disagree could then change from run to run.
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=0)
clf.fit(X, Y)

# Predict the species of one hand-made sample. Feature order is
# [sepal length, sepal width, petal length, petal width].
# NOTE: these values lie far outside the training data (a petal length of -1
# is not physically possible). The answer is therefore an extrapolation that
# depends on which features the tree split on.
sample = [[12, 1, -1, 10]]
print('类别是', iris_target_name[clf.predict(sample)[0]])

 

import numpy as np

import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

 

from sklearn.cluster import KMeans

from sklearn import datasets

 

# --- Part 2: K-means clustering of the Iris data ------------------------------

# The KMeans estimators below leave random_state at None, so scikit-learn
# draws their initial centroids from NumPy's global RNG. Seeding that RNG
# makes the clusterings (and therefore the plots) repeatable.
np.random.seed(5)

# Not referenced anywhere below in this script.
centers = [[1, 1], [-1, -1], [1, -1]]

iris = datasets.load_iris()
X = iris.data
y = iris.target

# Three clusterings to compare.
# n_init is spelled out because its default changed from 10 to 'auto' (a
# single run for k-means++) in scikit-learn 1.4. n_init=10 keeps the original
# behavior and silences the FutureWarning emitted by scikit-learn 1.2-1.3.
# The "bad_init" variant deliberately uses one random initialization.
estimators = {
    'k_means_iris_3': KMeans(n_clusters=3, n_init=10),
    'k_means_iris_8': KMeans(n_clusters=8, n_init=10),
    'k_means_iris_bad_init': KMeans(n_clusters=3, n_init=1, init='random'),
}

# Matplotlib figure number; incremented once per estimator plot below.
fignum = 1

for name, est in estimators.items():
    # Create (or reuse) this estimator's figure and clear anything on it.
    # The original had these two lines at column 0, which made the whole
    # loop an IndentationError.
    fig = plt.figure(fignum, figsize=(4, 3))
    plt.clf()

    # Add a 3-D axes covering the figure. Constructing Axes3D(fig, ...)
    # directly is no longer attached to the figure in Matplotlib 3.6+, which
    # left the plot blank (and made a following plt.cla() create a stray 2-D
    # axes). add_axes(..., projection='3d') attaches it explicitly.
    ax = fig.add_axes([0, 0, .95, 1], projection='3d', elev=48, azim=134)
    # Label the figure with the estimator's key so the three plots can be told apart.
    ax.set_title(name)

    est.fit(X)
    labels = est.labels_

    # Plot petal width / sepal length / petal length, colored by cluster id.
    # np.float was removed in NumPy 1.24; the builtin float is equivalent.
    ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=labels.astype(float))

    # Hide the tick labels (ax.w_xaxis etc. were removed in Matplotlib 3.8).
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.zaxis.set_ticklabels([])
    ax.set_xlabel('Petal width')
    ax.set_ylabel('Sepal length')
    ax.set_zlabel('Petal length')

    fignum += 1

 

# Plot the ground truth (true species) for comparison with the clusterings.
fig = plt.figure(fignum, figsize=(4, 3))
plt.clf()
# Attach the 3-D axes explicitly; Axes3D(fig, ...) alone is not added to the
# figure in Matplotlib 3.6+, which produced an empty plot.
ax = fig.add_axes([0, 0, .95, 1], projection='3d', elev=48, azim=134)

# Write each species name at the mean position of its points, nudged +1.5
# along the sepal-length axis. y still holds the original 0/1/2 labels here.
for name, label in [('Setosa', 0),
                    ('Versicolour', 1),
                    ('Virginica', 2)]:
    ax.text3D(X[y == label, 3].mean(),
              X[y == label, 0].mean() + 1.5,
              X[y == label, 2].mean(), name,
              horizontalalignment='center',
              bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))

# Reorder the labels to have colors matching the cluster results.
# The mapping [1, 2, 0] is presumably tuned by hand to the cluster ids that
# the seeded K-means run produces. np.float was removed in NumPy 1.24.
y = np.choose(y, [1, 2, 0]).astype(float)

# Scatter petal width / sepal length / petal length, colored by species.
ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y)

# Hide the tick labels (ax.w_xaxis etc. were removed in Matplotlib 3.8).
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
ax.zaxis.set_ticklabels([])

# Axis titles.
ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')

plt.show()

# posted @ 2021-05-10 21:26  溜了溜  阅读(247)  评论(0)  编辑  收藏  举报