机器学习实战笔记—决策树(代码讲解)
最近在学习机器学习实战,里面讲的算法虽然易于理解,但还是需要反复领悟,算法的思想与编写代码的技巧都是学习的重点。每章节涉及到其他的知识也都需要再查阅相关资料进行补充。本篇总结一下决策树,方便和大家一起交流。
序号 | 不浮出水面是否可以生存 | 是否有脚蹼 | 属于鱼类 |
1 | 是 | 是 | 是 |
2 | 是 | 是 | 是 |
3 | 是 | 否 | 否 |
4 | 否 | 是 | 否 |
5 | 否 | 是 | 否 |
在此,特征是:不浮出水面是否可以生存、是否有脚蹼。特征下的属性值都有两类:是、否。所属类别的属性值也是两类:是、否。
构建决策树后的流程图形式是:
长方形表示判断模块(decision block),椭圆形表示终止模块(terminating block),从判断模块引出的箭头称为分支,可以达到另一个判断模块或者终止模块。当然,已给定数据集下会有很多树的形式。所以,从最开始提出问题到最后构建决策树有以下几个思考的点:
(1)使用哪种数据结构来存储决策树。
(2)如何决定特征在树中的节点位置,也就是考虑当前数据集中哪个特征在分类时起到决定性作用。
(3)如何构建决策树。
(4)决策树的可视化表示。
二、存储决策树的数据结构
在此,我们使用python中的字典类型,来存储树。比如以上示例对应的字典类型为:{'No Surfacing':{0:'No',1:{'Flippers':{0:'No',1:'Yes'}}}},决策树分支上的值(属性)在这里 'No' 表示为 0,'Yes' 表示为 1,叶节点(所属类别)表示仍为'No','Yes'。这种字典表示形式稍微有些复杂,但还是比较形象,容易理解的。
三、决策树非叶节点的位置
1、概述
这节主要讲述,怎么决定决策树非叶节点的位置,也就是考虑在当前数据集中哪个特征对分类起到决定性作用。对分类结果有影响的几个特征,我们认为它们的影响程度是有差异的,比如,判定类别是不是鸟,假定特征有:会不会飞 和 是否卵生。那么 会不会飞 这个特征对大部分鸟来说属性都是Yes(当然除了企鹅和鸵鸟以外),而卵生的不一定是鸟,我们就说这个特征是比较有决定性的。
首先选出原数据集中最有决定性的特征,通过这个特征划分数据集,得到数据子集。这些数据子集会分布在第一个决策特征的所有分支上。如果某个分支下的数据属于同一类型就无需再分,如果没有数据子集中的数据不属于同一类,则继续按照原数据集选决定性特征的方式进行划分数据子集。直到所有同类型的数据在一个子集中或者数据子集已无可用来划分的特征。
2、熵的计算
那么怎么选择具有决定性的特征呢?为了找到决定性特征,划分出最好的结果,我们必须评估每个特征。
直面上看,拿上面的例子来说,特征1 不浮出水面是否可以生存 划分数据分类,将属性值为'是'的序号1、2、3归为一类 对应所属类别列表为['Yes','Yes','No'],属性值为'否'的4、5归为另一类 所属类别列表为['No','No']。而用特征2 是否有脚蹼 进行划分,得到 属性值为'是' 的序号1、2、4、5归为一类 对应类别列表为['Yes','Yes','No','No'],属性值是 '否'的3 归为另一类 类别列表['No']。这样比较所划分的类别列表,可以得出特征1 是当前最有决定性的特征。
理论上讲,引入熵和信息增益的概念。其实熵就是混乱的度量,熵越高,混合的数据越多。所以使用每个特征划分数据集,使得划分后数据子集的熵最小(也就是划分后混乱程度最小)的特征,该特征就是最具有决定性的特征。下面讲熵的具体计算。
熵定义为信息的期望值,先看看信息的计算。如果 xi 类别被划分在含有x1,x2...xn 类的集合中(混合类), xi 的信息定义为如下:
其中,p(xi)是选择该分类的概率。可以看出 p(xi) 越大,信息越小。也就是集合中类的种类越单一,信息也就越小。
熵的计算如下,集合中 所有类别对应的信息*概率 的和即为信息的期望:
使用python计算信息熵,给定数据集dataSet,数据类型为列表。以上示例可以表示为[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]。
1 def calcShannonEnt(dataSet): 2 #函数功能:计算给定数据集信息熵 3 numEntries = len(dataSet) 4 labelCounts = {} #字典存储 键:类别,值:类别的个数 5 for featVec in dataSet: 6 currentLabel = featVec[-1] 7 if currentLabel not in labelCounts.keys(): 8 labelCounts[currentLabel] = 0 9 labelCounts[currentLabel] += 1 10 11 #计算 所有类别的信息*概率 之和 12 shannonEnt = 0.0 13 for key in labelCounts: 14 prob = float(labelCounts[key])/numEntries 15 shannonEnt -= prob * log(prob,2) 16 return shannonEnt
3、划分数据集
上小节知道了怎么去计算熵,也就是度量集合的混乱程度。信息增益是熵的减少,也就是混乱程度的减少。那么对每个特征划分数据集,分别计算原数据集和划分后的数据子集的熵之差,计算信息增益。最后使用信息增益来选择最好的特征划分,当然,信息增益越大使得划分效果越好。
这块内容分为两个步骤:一是按照给定特征值来划分数据集,二是找到最有决定性的特征,返回。
步骤二需使用步骤一的函数来进行计算划分前后的信息增益。
步骤一代码:
1 #函数功能:按照给定的特征值,来进行划分数据集 2 def splitDataSet(dataSet, axis, value): 3 retDataSet = [] 4 for featVec in dataSet: 5 if featVec[axis] == value: 6 reducedFeatVec = featVec[:axis] 7 reducedFeatVec.extend(featVec[axis+1:]) 8 9 #reducedFeatVec 存储 去除给定特征值后的数据 10 retDataSet.append(reducedFeatVec) 11 return retDataSet
步骤二代码:
1 def chooseBestFeatureToSplit(dataSet): 2 numFeatures = len(dataSet[0]) - 1 #计算每条数据共有多少个特征 3 baseEntropy = calcShannonEnt(dataSet) #计算原数据集的熵(混乱程度) 4 bestInfoGain = 0.0 #初始化 最好的信息增益 5 bestFeature = -1 6 for i in range(numFeatures): #遍历所有特征,最后计算每个特征的信息增益,选取最好的 7 featList = [example[i] for example in dataSet] #创建对应特征的属性值列表 8 uniqueVals = set(featList) #得到当前特征下的不同的属性值集合 9 newEntropy = 0.0 #初始化熵 10 11 #计算当前特征划分后的数据子集的熵 12 for value in uniqueVals: 13 subDataSet = splitDataSet(dataSet, i, value) 14 prob = len(subDataSet)/float(len(dataSet)) 15 newEntropy += prob * calcShannonEnt(subDataSet) 16 17 infoGain = baseEntropy - newEntropy #计算当前特征划分后的数据子集的信息增益 18 19 if (infoGain > bestInfoGain): #计算的信息增益如果比现有的大,则重新赋值。找到最好的信息增益 20 bestInfoGain = infoGain 21 bestFeature = i 22 return bestFeature #返回最有决定性的特征标识
到此为止,我们学习了如何度量数据集的混乱程度,如何计算熵,如何找到最有决定性的特征,有效划分数据集。接下来结合之前的函数讲如何构造决策树。
四、递归构建决策树
上面我们也介绍了决策树的数据类型是字典形式,那么在此解决的问题如何递归创建该字典。
基本原理是:给定原始数据集,基于最好的特征划分数据集,根节点为该特征,分支为对应的不同的属性值,有多少种属性值,就有多少个分支。不同分支指向的节点是去除分支上属性值后的数据子集。节点上的数据子集可以再次依照相同的方式被划分。所以此处可以使用递归的思想。递归结束的条件:程序遍历完所有可用于划分的特征,或者每个分支下所有实例属于相同的分类。
数据从根节点出发,按照规则可以到达叶子节点,也必然属于该分类。
代码如下:
1 def createTree(dataSet,labels): 2 classList = [example[-1] for example in dataSet] 3 if classList.count(classList[0]) == len(classList): 4 return classList[0] #结束递归的条件:数据集类别属于同一类 5 if len(dataSet[0]) == 1: #结束递归的条件:已没有用于划分的特征 6 return majorityCnt(classList) 7 8 bestFeat = chooseBestFeatureToSplit(dataSet) #选择最好的特征 9 bestFeatLabel = labels[bestFeat] 10 11 myTree = {bestFeatLabel:{}} #使用最好的特征初始化决策树 12 del(labels[bestFeat]) 13 14 featValues = [example[bestFeat] for example in dataSet] 15 uniqueVals = set(featValues) #得到最好特征对应的不同属性值集合 16 for value in uniqueVals: 17 subLabels = labels[:] 18 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels) #使用划分后的数据子集 和 剩下的特征列表 递归构建字典 19 return myTree
五、树的可视化表示
1、构建决策树的总体思路
上面已经知道怎么创建一个用于存储决策树的字典了。已给数据集,找到当前最具决定性的特征划分数据集得到数据子集。该特征作为当前树的根节点,对应的不同属性值作为分支,有可能不止两个属性值,所以节点会超过两个分支。每个分支对应一个数据子集,数据子集再按照相同方式进行划分。直到没有可用特征或者数据子集的类别都相同。再找最具决定性的特征时就涉及到了熵和信息增益的定义。
2、利用Matplotlib注释绘制决策树
在此,使用固定的基坐标轴(x、y轴的绘制范围都是从0到1)动态绘制树,无论树的大小都可以进行缩放,使树位于坐标轴的中心位置。
那么接下来就是考虑数学问题了,怎么去动态的适应基坐标轴,根据树的大小怎么去确定具体的节点位置。大致做法应该是每个非叶节点都位于属于它的叶节点的中间位置,所有叶节点在水平位置上距离相等。设该树所有叶节点水平位置上之间的距离 d=( 1/叶节点的个数)。
这部分代码如下,希望下面解释的部分结合代码看更容易理解。递归求出树的宽度(叶节点的个数)W和高度(层数)H 不多说,注释绘图部分 不多说,只解释下怎么求解树节点的位置。
以下只考虑x坐标的计算,y坐标比较容易。分两种情况讨论,一是非叶节点,二是叶节点。
对于非叶节点,每个非叶节点的x位置都可以分别用它左边最近的叶节点进行求解。设某非叶结点 A 有n个叶节点,它左边最近叶节点 a 位置为(xOff,yOff),那么A应位于属于它的所有叶节点的中间位置,即水平位置上离 a 的间隔是(n+1)/2。那么它的 x坐标=xOff+d*(n+1)/2。
对于叶节点,每个叶节点的x位置也可以用它左边最近的叶节点进行求解。左边最近的叶节点坐标位置(xOff,yOff),即 某叶节点x坐标是 xOff+d。
对于树的根节点,按照非叶结点同等对待,但可以知道它的坐标是(0.5,1),最初的xOff应该设为 0.5-d*(n+1)/2,又 n=1/d ,即xOff 初值为 -0.5d。
1 import matplotlib.pyplot as plt 2 3 decisionNode = dict(boxstyle="sawtooth", fc="0.8") 4 leafNode = dict(boxstyle="round4", fc="0.8") 5 arrow_args = dict(arrowstyle="<-") 6 7 #递归求解树的宽度 8 def getNumLeafs(myTree): 9 numLeafs = 0 10 firstStr = myTree.keys()[0] 11 secondDict = myTree[firstStr] 12 for key in secondDict.keys(): 13 if type(secondDict[key]).__name__=='dict': 14 numLeafs += getNumLeafs(secondDict[key]) 15 else: numLeafs +=1 16 return numLeafs 17 18 #递归求解树的深度 19 def getTreeDepth(myTree): 20 maxDepth = 0 21 firstStr = myTree.keys()[0] 22 secondDict = myTree[firstStr] 23 for key in secondDict.keys(): 24 if type(secondDict[key]).__name__=='dict': 25 thisDepth = 1 + getTreeDepth(secondDict[key]) 26 else: thisDepth = 1 27 if thisDepth > maxDepth: maxDepth = thisDepth 28 return maxDepth 29 30 #使用文本注解绘制树节点 31 def plotNode(nodeTxt, centerPt, parentPt, nodeType): 32 createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', 33 xytext=centerPt, textcoords='axes fraction', 34 va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) 35 36 #绘制分支上的值,计算父节点和子节点的中间位置,添加简单的文本信息 37 def plotMidText(cntrPt, parentPt, txtString): 38 xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] 39 yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] 40 createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) 41 42 #绘制树 43 def plotTree(myTree, parentPt, nodeTxt): 44 45 numLeafs = getNumLeafs(myTree) #得到当前树的宽度 46 depth = getTreeDepth(myTree) #得到当前树的深度 47 firstStr = myTree.keys()[0] #当前树根节点的描述信息 48 49 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #根据左边临近的叶节点坐标得到当前非叶结点的坐标(有解释) 50 plotMidText(cntrPt, parentPt, nodeTxt) #父节点和子节点的中间位置,添加简单的文本信息 51 plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制 当前非叶节点 52 secondDict = myTree[firstStr] 53 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD 54 55 for key in secondDict.keys(): #遍历当前节点的子节点 56 if type(secondDict[key]).__name__=='dict': #如果子节点类型是非叶节点,递归画树 57 plotTree(secondDict[key],cntrPt,str(key)) 58 else: #如果子节点是叶节点,直接计算叶节点的坐标,绘制 59 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW 60 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 61 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 62 63 plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #递归完后需要回退到上层,绘制当前树根节点的其他分支节点。 64 65 #主函数,调用 绘制树 函数 66 def createPlot(inTree): 67 fig = plt.figure(1, facecolor='white') 68 fig.clf() 69 axprops = dict(xticks=[], yticks=[]) 70 createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 71 plotTree.totalW = float(getNumLeafs(inTree)) 72 plotTree.totalD = float(getTreeDepth(inTree)) 73 74 #需要初始化虚拟的根节点左边最邻近的叶节点(虽然不存在此节点),将根节点等同于其他非叶节点计算坐标位置。递归需要。 75 plotTree.xOff = -0.5/plotTree.totalW 76 plotTree.yOff = 1.0 77 plotTree(inTree, (0.5,1.0), '') 78 plt.show()