探索sklearn | 鸢尾花数据集

1 鸢尾花数据集背景

鸢尾花数据集是原则20世纪30年代的经典数据集。它是用统计进行分类的鼻祖。

sklearn包不仅囊括很多机器学习的算法,也自带了许多经典的数据集,鸢尾花数据集就是其中之一。

导入的方法很简单,不过我比较好奇它是如何来存储这些数据的,于是我决定去背后看一看

from sklearn.datasets import load_iris

data = load_iris()

 找到sklearn包的路径,发现包可不少,不过现在扔在一边,以后再来探索,我现在要找到是datasets文件夹。

文件夹里没有找到load_iris()这个函数在哪,只是在__init__文件里,发现了这么一行

from .base import load_iris

 

2 数据的内容

不出我料数据没有存储在程序文件里,而是用csv格式保存着,单独放在了data文件夹里

150,4,setosa,versicolor,virginica
5.1,3.5,1.4,0.2,0 #花萼长度,花萼宽度,花瓣长度,花瓣宽度
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
5.0,3.6,1.4,0.2,0

 第一行首先记录了样本数目150,特征数目4

现在是时候来详细介绍一下数据了:

数据包含三种鸢尾花的四个特征,分别是花萼长度(cm)、花萼宽度(cm)、花瓣长度(cm)、花瓣宽度(cm),这些形态特征在过去被用来识别物种。时至今日,我们已经可以通过基因签名来识别这些分类了。

三种鸢尾花分别是

山鸢尾花(Iris Setosa)、

变色鸢尾花(Iris Versicolor)和

维吉尼亚鸢尾花(Iris Virginica)

 

3 数据可视化

鸢尾花数据集只有150个样本,每个样本只有4个特征,容易将其可视化

上面加载的data变量是一个类似字典的类型,是数据信息的集合,它像字典一样通过键值对来组织信息

值既可以通过data['target']也可以通过data.target来获取,很明显这说明data并不是字典类型

data.keys()
>>['target_names', 'data', 'target', 'DESCR', 'feature_names']
feature = data['data'] #为numpy.ndarray类型
feature.shape #矩阵的行数和劣势
>> (150L, 4L)
target = data['target']
target.shape
>>(150L,)

 

 四个特征是不可能同时在平面图里画出来的,只得运用我们的聪明才智,把它两两一组

def plot_iris_projection(x_index, y_index):
    for t,marker,c in zip(xrange(3),'>ox', 'rgb'):
        plt.scatter(data[target==t,x_index],
                    data[target==t,y_index],
                    marker=marker,c=c)
        plt.xlabel(feature_names[x_index])
        plt.ylabel(feature_names[y_index])

pairs = [(0,1),(0,2),(0,3),(1,2),(1,3),(2,3)] for i,(x_index,y_index) in enumerate(pairs): plt.subplot(2,3,i) plot_iris_projection(x_index, y_index) plt.show()

 

 

不难发现的是,不论在那两个特征下,山鸢尾花都能很好的和其他两种鸢尾花区分,但是另外两种鸢尾花的特征比较焦灼,如果只有这四个特征,有时人都难以区分。

数据可视化最高只能是三维,matplotlib也能胜任此工作

from mpl_toolkits.mplot3d import Axes3D

def plot_iris_projection3d(x_index, y_index, z_index):
    fig = plt.figure()
    ax = fig.add_subplot(111,projection='3d')
    for t,marker,c in zip(xrange(3),'>ox', 'rgb'):
        ax.scatter(data[target==t,x_index],
                    data[target==t,y_index],
                    data[target==t,z_index],
                    marker=marker,c=c)
        ax.set_xlabel(feature_names[x_index])
        ax.set_ylabel(feature_names[y_index])
        ax.set_zlabel(feature_names[z_index])
        
plot_iris_projection3d(1, 2, 3)
plt.show()

 

posted on 2015-02-04 12:18  meelo  阅读(12228)  评论(1编辑  收藏  举报