机器学习-决策树算法
date: 2019-04-07 00:48
一、基本流程
- 1、初始化属性集合和数据集合
- 2、从数据集中选择最优划分属性,并以该属性为当前决策节点
- 3、更新数据集合和属性集合(删除掉上一步中使用的属性,并按照属性值来划分不同分支的数据集合)
- 4、依次对每种取值情况下的子集重复第2步
- 5、若子集只包含单一属性,则为分支为叶子节点,根据其属性值标记。
- 6、完成所有属性集合的划分
决策树生成过程是一个递归的过程,以下三种情况会导致递归停止:
- 1.当前节点包含的样本全部属于同一类别,无需划分
- 2.当前属性集为空,或是所有样本在所有属性上取值相同,无法划分
- 3.当前节点包含的样本集合为空,不能划分
从决策树生成的整个过程来看,其中最核心的应该是第2步:选择最优划分属性,这也是不同决策树算法之间的区别,下文讲解如何选择的。
二、划分选择
随着划分过程的不断进行,决策树的分支节点所包含的样本尽可能属于同一类别,即节点的“纯度”(purity)越来越高。
1.信息增益(ID3决策树算法)
信息熵(information entropy):度量样本集合纯度最常用的一种指标。
对于数据集D:
其中$$p_{k}$$表示第k类样本在D中所占的比例。Ent(D)的值越小,在D的纯度越高。
信息增益(information gain):属性划分数据集前后信息熵的差值。
假定离散属性a有V个可能的取值$${a{1},a,...,a{V}}$$,若使用a来对样本集合D进行划分,则会产生V个分支节点,其中第v个分支节点包含了D中所有在属性取值为$$a$$的样本,记为$$D^{v}$$。那么用属性a划分数据集D所获得的信息增益为:
信息增益越大,意味着用a属性来进行划分所获得的“纯度提升”越大。
著名的ID3决策树算法就是以信息增益为准则来选择属性划分的。
缺点:信息增益准则对可取值数据较多的属性有所偏好。
2.增益率(C4.5决策树算法)
增益率(gain ratio)的定义为:
其中
IV(a)称为a的“固有值”。属性a的可能取值数目越多(即V越大),则IV(a)的值通常会越大。
增益率对可取值数目较少的属性有偏好。因此C4.5算法使用了一个启发式:先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。
3.基尼指数(CART决策树算法)
CART决策树使用基尼指数来划分属性。
对于数据集D:
Gini(D)反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因此,Gini(D)越小,则数据集D的纯度越高。
属性a的基尼指数定义为:
于是在候选属性集合A中,选择那个使得划分基尼指数最小的属性作为划分属性,即$$a_{*}=\underset{a\epsilon A}{argmin\ Gini_index(D,a)}$$
三、剪枝处理
预剪枝是指在决策树生成过程中,对每一个节点在划分前先进行估计,若当前节点的划分不能带来决策树泛化性能提升,则停止划分并将当前节点标记为叶节点。
后剪枝是先从训练集生成一棵完整的决策树,然后自底向上地对非叶节点进行考察,若将该节点对应的子树替换为叶节点能带来决策树泛化性能提升,则将该子树替换为叶节点。
效果对比:
后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树。但是后剪枝训练时间开销比未剪枝和预剪枝决策树都要大的多。
四、推到过程
五、算法实现
手写代码实现算法便于理解算法本质:决策树
六、sklearn库中决策树的使用方法
scikit-learn 使用 CART 算法的优化版本。
import os
import time
import numpy as np
from sklearn import tree
from sklearn.externals.six import StringIO
from sklearn.model_selection import train_test_split
print('Step 1.Loading data...')
X_train,X_test,Y_train,Y_test = ...
print('---Loading and splitting completed.')
print('Step 2.Training...')
startTime = time.time()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train,Y_train)
print('---Training Completed.Took %f s.'%(time.time()-startTime))
print('Step 3.Testing...')
Y_predict = clf.predict(X_test)
matchCount = 0
for i in range(len(Y_predict)):
if Y_predict[i] == Y_test[i]:
matchCount += 1
accuracy = float(matchCount/len(Y_predict))
print('---Testing completed.Accuracy: %.3f%%'%(accuracy*100))
参考:
- 《机器学习》周志华
- Scikit-learn官方文档
- Scikit-learn 0.19.x 中文文档