Loading

对鸢尾花数据进行分类的思路

对鸢尾花数据进行分类

1 数据集处理

加载数据集,IRIS 数据集在 sklearn 模块中已经提供
from sklearn import datasets
iris = datasets.load_iris()
iris_feature = iris.data
iris_target = iris.target
将150个样本分割为90个训练集和60个测试集
feature_train, feature_test, target_train, target_test = train_test_split(iris_feature, iris_target, test_size=0.4,random_state=40)

2 决策树分类(实现)

这里选用CART算法。算法从根节点开始,用训练集递归构建分类树。在决策树的构建中,有时会造成决策树分支过多,这是就需要去掉一些分支,降低过度拟合。通过决策树的复杂度来避免过度拟合的过程称为剪枝。

创建决策树:
步骤1:选择GiniIndex最小的维度作为分割特征。(GiniIndex计算方式见PPT)
步骤2:如果数据集不能再分割,即GiniIndex为0或只有一个数据,该数据集作为一个叶子节点。
步骤3:对数据集进行二分割
步骤4:对分割的数据集1重复步骤1、2、3,创建true子树
步骤5:对分割的数据集2重复步骤1、2、3,创建false子树

明显的递归算法。

剪枝:
需要从训练集生成一棵完整的决策树,然后自底向上对非叶子节点进行考察。判断是否将该节点对应的子树替换成叶节点。当节点的gain小于给定的 mini Gain时则合并这两个节点.。

测试:
通过对测试集的预测来验证准确性。对于不同的划分方式,即选取不同随机数种子且保持90:60的比例,训练集准确率为:100 %,测试集准确率为:91.67 %

3 SVM分类(调库)

支持向量机的基本模型是定义在特征空间上的间隔最大的线性分类器,即求一个分离超平面,这个超平面使得离它最近的点能够最远。

通过点到超平面的距离公式,找到最大间隔的优化模型,即找每个超平面对应着一个间隔,就是要找出所有间隔中最大的那个值对应的超平面。

需要注意,如果数据集中存在噪点的话,那么在求超平的时候就会出现很大问题,因此需要引入一个松弛变量ξ来允许一些数据可以处于分隔面错误的一侧。

以上讨论的都是在线性可分情况进行讨论的,但是实际问题中给出的数据并不是都是线性可分的。需要使用核函数解决这个问题。
在经过训练集测试后,训练集准确率: 0.9444,测试集准确率:0.9833

4 BPNN分类(实现)

BP神经网络是一种多层前馈神经网络,该网络的主要特点是信号前向传递,误差反向传播。在前向传播的过程中,输入信号从输入层经隐含层处理,直至输出层。如果输出层得不到期望输出,则转入反向传播,根据预测误差调整网络权值和阈值,从而使BP神经网络预测输出不断逼近期望输出。

对于鸢尾花数据集,搭建一个输入层(4)-隐含层(10)-输出层(1)的神经网络。
4:feature维度数;10:隐含层神经元数量;1:target维度数。

实现并训练BP神经网络,主要有以下几个步骤:
步骤1:初始化神经网络参数,主要是确定结构和初始化权重。
步骤2:正向传播计算。先对数据进行归一化处理,使其在0-1范围内,中间层和输出层激活函数都为sigmoid。
步骤3:成本函数计算,使用一个样本的期望输出和实际输出的误差的平方用来定义损失函数
步骤4:反向传播计算,修正权重参数,以提高拟合效果。

测试:
用训练集通过上述步骤得到一个神经网络模型,然后代入测试集数据,我们将0-1分为3份(这里用0-0.2-0.8-1),来作为三种类别,训练集准确率为:95.555556 %
测试集准确率为:93.333333 %。

5 KNN分类(调库)

KNN算法,即K近邻算法。选取训练集中离该数据最近的 k 个点,它们中的大多数属于哪个类别,则该新数据就属于哪个类别。

k 的选择是一个超参数的选择问题,需要通过调整 K 的值确定最好的 K,最好选奇数,否则会出现同票。可以通过交叉验证法确定模型的最佳 k 值。

度量距离的方式,一般为 Lp 距离:p = 1 时,为曼哈顿距离:p = 2 时,为欧式距离:欧式距离是我们最常用的计算距离的方式。

分类的规则,采取多数表决的原则,即由输入实例的 k 个近邻的训练实例中的多数类决定输入实例的类。

需要注意:

  1. 不同特征有不同的量纲,必要时需进行特征归一化处理
  2. kNN 的时间复杂度为O(DNN),D 是维度数,N 是样本数,这样,在特征空间很大和训练数据很大时,kNN 的训练时间会非常慢。

在训练模型后,训练集准确率: 0.956,测试集准确率: 0.917。

posted @ 2020-11-17 16:20  iterationjia  阅读(1021)  评论(0编辑  收藏  举报