决策树介绍及算法
定义:
决策树是一种分而治之(Divide and Conquer)的决策过程。一个困难的预测问题, 通过树的分支节点, 被划分成两个或多个较为简单的子集,从结构上划分为不同的子问题。将依规则分割数据集的过程不断递归下去(Recursive Partitioning)。随着树的深度不断增加,分支节点的子集越来越小,所需要提的问题数也逐渐简化。当分支节点的深度或者问题的简单程度满足一定的停止规则(Stopping Rule)时, 该分支节点会停止劈分,此为自上而下的停止阈值(Cutoff Threshold)法;有些决策树也使用自下而上的剪枝(Pruning)法。
组成部分
分支节点
正如名称所指, 分支节点决定输入数据进入哪一个分支。每个分支节点对应一个分支函数(劈分函数),将不同的预测变量的值域映射到有限,离散的分支上。
根节点
根节点是一个特殊的分支节点,它是决策树的起点。
对于决策树来说,所有节点的分类或者回归目标都要在根节点已经定义好了。如果决策树的目标变量是离散的(序数型或者是列名型变量),则称它为分类树(Classification Tree);如果目标变量是连续的(区间型变量),则称它为回归树(Regression Tree)。
叶节点
叶节点存储了决策树的输出。对于分类问题,所有类别的后验概率都存储在叶节点,观测走过了全树从上到下的某一条路径(决策过程)之后会根据叶子节点给出一个“观测属于哪一类”的预报;对于回归问题,叶子结点上存储了训练集目标变量的中位数,不同观测走过决策路径后如果到达了相同的叶子结点,则对它们给出相同预报。
训练
一棵决策树由分支节点(树的结构)和叶节点(树的输出)组成. 决策树的训练的目标是通过最小化某种形式的损失函数或者经验风险, 来确定每个分支函数的参数,以及叶节点的输出.
决策树自上而下的循环分支学习(Recursive Regression)采用了贪心算法。每个分支节点只关心自己的目标函数。具体来说, 给定一个分支节点, 以及落在该节点上对应样本的观测(包含自变量与目标变量),选择某个(一次选择一个变量的方法很常见)或某些预测变量,也许会经过一步对变量的离散化(对于连续自变量而言),经过搜索不同形式的分叉函数且得到一个最优解(最优的含义是特定准则下收益最高或损失最小)。这个分支过程, 从根节点开始, 递归进行, 不断产生新的分支, 直到满足结束准则时停止。整个过程和树的分支生长非常相似。
测试(预测)
下图提供了一个简单的例子说明决策树的测试过程. 测试样本为
- 进入根节点, 分支函数 小于0, 进入左分支
- 分支函数 大于0, 进入右分支
- 分支函数 大于0, 进入右分支, 已经到达叶节点
- 假设是三分类问题, 该叶节点上所有类别的后验概率是 (0.1, 0.7, 0.2), 那么该决策树预测输入样本属于第二类.
可采用ID3算法思路:如果以某种特种特征来划分数据集,会导致数据集发生最大程度的改变,那么就使用这种特征值来划分。
假设现在创建了一个数据集,代码如下:本例为ID3代码,其他的如c4.5算法,cart算法之后会补上。
1 def createDataSet(): 2 dataSet = [[1,1,'yes'], 3 [1,1,'yes'], 4 [1,0,'no'], 5 [0,1,'no'], 6 [0,1,'no']] 7 labels = ['ct1','ct2'] 8 return dataSet, labels
计算数据集的香农熵代码:
from math import log def calcShannonEnt(dataSet): numEntries = len(dataset) //求出数据有多少个对象,即有多少行,计算实例总数 labelCounts = {} for featVec in dataset: //对数据集逐行求最后一类的数据,并将统计最后一列数据的数目 currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): //创建一个字典,键值是最后一列的数值 labelCountspcurrentLabel] = 0 //当前键值不存在,则扩展字典并将此键值加入字典 labelCounts[currentLabel] += 1 //每个键值都记录了当前类别出现的次数 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries //使用统计出的最后一列的数据来计算所有类标签出现的概率 shannonEnt -= prob * log(prob,2) //下面的公式 return shannonEnt
H=-∑p(xi)log(2,p(xi)) (i=1,2,..n)
熵越高,则混合的数据越多,数据的不纯度越大。得到熵,就可以按照获取最大信息增益的方法来划分数据集。
划分数据集的代码
1 def splitDataSet(dataSet, axis, values): 2 retDataSet = [] 3 for featVec in dataSet: 4 if featVec[axis] == value: //判断axis列的值是否为value 5 reducedFeatVec = featVec[:axis] //[:axis]表示前axis行,即若axis为2,就是取featVec的前axis行 6 reducedFeatVec.extend(featVec[axis+1:]) //[axis+1:]表示从跳过axis+1行,取接下来的数据 7 retDataSet.append(reducedFeatVec) 8 return retDataSet
执行完上面的代码,数据就会将符合值判定的行取出来,然后将这些行里用来判定值的列去除,剩下的数据就是划分完的数据集
接下来的代码是选择最好的数据集划分方式
1 def chooseBestFeatureToSplit(dataSet): 2 numFeatures = len(dataSet[0]) - 1 3 baseEntropy = calcShannonEnt(dataSet) 4 bestInfoGain = 0.0 5 bestFeature = -1 6 for i in range(numFeatures): 7 featList = [example[i] for example in dataSet] 8 uniqueVals = set(featList) //创建了一个列表,里面的元素是dataSet所有的元素,但不重复 9 newEntropy = 0.0 10 for value in uniqueVals: 11 subDataSet = splitDataSet(dataSet, i, value) 12 prob = len(subDataSet)/float(len(dataSet)) 13 newEntropy += prob * calcShannonEnt(subDataSet) //计算按每个数据特征来划分的数据集的熵 14 infoGain = baseEntropy - newEntropy 15 if (infoGain > bestInfoGain): 16 bestInfoGain = infoGain 17 bestFeature = i //判断出哪种划分方式得到最大的信息增益,且得出该划分方式所选择的特征 18 return bestFeature
在第一次划分之后,数据将被传递到下一个节点,在此节点上可以再次划分数据,因此可以用递归来处理数据集。递归结束的条件是,程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据都属于叶子节点的分类。下面的代码用来判断递归何时结束。
1 import operator 2 def majorityCnt(classList): //参数classList在下面创建树的代码里,是每一行的最后一个特征 3 classCount = {} 4 for vote in classList: //将特征里的元素添加到新建的字典里作为键值,并统计该键值出现次数 5 if vote not in classCount.keys(): classCount[vote] = 0 6 classCount[vote] += 1 7 sortedClassCount = sorted(classCount.iteritems(), key=operater.itemgetter(1), reverse=True) 8 return sortedClassCount[0][0] //返回出现次数最多的键值
下面是创建树的代码,该代码调用了上面多个模块。
1 def createTree(dataSet, labels): 2 classList = [example[-1] for example in dataSet] 3 if classList.count(classList[0]) == len(classList): 4 return classList[0] 5 if len(dataSet[0]) == 1: 6 return majorityCnt(classList) 7 bestFeat = chooseBestFeatureToSplit(dataSet) 8 bestFeatLabel = labels[bestFeat] 9 myTree = {bestFeatLabel:{}} 10 del(labels[bestFeat]) 11 featValues = [example[bestFeat] for example in dataSet] 12 uniqueVals = set(featValues) 13 for value in uniqueVals: 14 subLabels = labels[:] 15 myTree[bestFeatLabel][value] = creatTree(splitDataSet(dataSet, bestFeat, value), subLabels) 16 return myTree