决策树的python实现

决策树

算法优缺点:

  • 优点:计算复杂度不高,输出结果易于理解,对中间值缺失不敏感,可以处理不相关的特征数据

  • 缺点:可能会产生过度匹配的问题

  • 适用数据类型:数值型和标称型

算法思想:

1.决策树构造的整体思想:

决策树说白了就好像是if-else结构一样,它的结果就是你要生成这个一个可以从根开始不断判断选择到叶子节点的树,但是呢这里的if-else必然不会是让我们认为去设置的,我们要做的是提供一种方法,计算机可以根据这种方法得到我们所需要的决策树。这个方法的重点就在于如何从这么多的特征中选择出有价值的,并且按照最好的顺序由根到叶选择。完成了这个我们也就可以递归构造一个决策树了

2.信息增益

划分数据集的最大原则是将无序的数据变得更加有序。既然这又牵涉到信息的有序无序问题,自然要想到想弄的信息熵了。这里我们计算用的也是信息熵(另一种方法是基尼不纯度)。公式如下:

数据需要满足的要求:

1 数据必须是由列表元素组成的列表,而且所有的列白哦元素都要具有相同的数据长度
2 数据的最后一列或者每个实例的最后一个元素应是当前实例的类别标签

函数:

calcShannonEnt(dataSet)
计算数据集的香农熵,分两步,第一步计算频率,第二部根据公式计算香农熵
splitDataSet(dataSet, aixs, value)
划分数据集,将满足X[aixs]==value的值都划分到一起,返回一个划分好的集合(不包括用来划分的aixs属性,因为不需要)
chooseBestFeature(dataSet)
选择最好的属性进行划分,思路很简单就是对每个属性都划分下,看哪个好。这里使用到了一个set来选取列表中唯一的元素,这是一中很快的方法
majorityCnt(classList)
因为我们递归构建决策树是根据属性的消耗进行计算的,所以可能会存在最后属性用完了,但是分类还是没有算完,这时候就会采用多数表决的方式计算节点分类
createTree(dataSet, labels)
基于递归构建决策树。这里的label更多是对于分类特征的名字,为了更好看和后面的理解。

  1.  1 #coding=utf-8
     2 import operator
     3 from math import log
     4 import time
     5 
     6 def createDataSet():
     7     dataSet=[[1,1,'yes'],
     8             [1,1,'yes'],
     9             [1,0,'no'],
    10             [0,1,'no'],
    11             [0,1,'no']]
    12     labels = ['no surfaceing','flippers']
    13     return dataSet, labels
    14 
    15 #计算香农熵
    16 def calcShannonEnt(dataSet):
    17     numEntries = len(dataSet)
    18     labelCounts = {}
    19     for feaVec in dataSet:
    20         currentLabel = feaVec[-1]
    21         if currentLabel not in labelCounts:
    22             labelCounts[currentLabel] = 0
    23         labelCounts[currentLabel] += 1
    24     shannonEnt = 0.0
    25     for key in labelCounts:
    26         prob = float(labelCounts[key])/numEntries
    27         shannonEnt -= prob * log(prob, 2)
    28     return shannonEnt
    29 
    30 def splitDataSet(dataSet, axis, value):
    31     retDataSet = []
    32     for featVec in dataSet:
    33         if featVec[axis] == value:
    34             reducedFeatVec = featVec[:axis]
    35             reducedFeatVec.extend(featVec[axis+1:])
    36             retDataSet.append(reducedFeatVec)
    37     return retDataSet
    38     
    39 def chooseBestFeatureToSplit(dataSet):
    40     numFeatures = len(dataSet[0]) - 1#因为数据集的最后一项是标签
    41     baseEntropy = calcShannonEnt(dataSet)
    42     bestInfoGain = 0.0
    43     bestFeature = -1
    44     for i in range(numFeatures):
    45         featList = [example[i] for example in dataSet]
    46         uniqueVals = set(featList)
    47         newEntropy = 0.0
    48         for value in uniqueVals:
    49             subDataSet = splitDataSet(dataSet, i, value)
    50             prob = len(subDataSet) / float(len(dataSet))
    51             newEntropy += prob * calcShannonEnt(subDataSet)
    52         infoGain = baseEntropy -newEntropy
    53         if infoGain > bestInfoGain:
    54             bestInfoGain = infoGain
    55             bestFeature = i
    56     return bestFeature
    57             
    58 #因为我们递归构建决策树是根据属性的消耗进行计算的,所以可能会存在最后属性用完了,但是分类
    59 #还是没有算完,这时候就会采用多数表决的方式计算节点分类
    60 def majorityCnt(classList):
    61     classCount = {}
    62     for vote in classList:
    63         if vote not in classCount.keys():
    64             classCount[vote] = 0
    65         classCount[vote] += 1
    66     return max(classCount)         
    67     
    68 def createTree(dataSet, labels):
    69     classList = [example[-1] for example in dataSet]
    70     if classList.count(classList[0]) ==len(classList):#类别相同则停止划分
    71         return classList[0]
    72     if len(dataSet[0]) == 1:#所有特征已经用完
    73         return majorityCnt(classList)
    74     bestFeat = chooseBestFeatureToSplit(dataSet)
    75     bestFeatLabel = labels[bestFeat]
    76     myTree = {bestFeatLabel:{}}
    77     del(labels[bestFeat])
    78     featValues = [example[bestFeat] for example in dataSet]
    79     uniqueVals = set(featValues)
    80     for value in uniqueVals:
    81         subLabels = labels[:]#为了不改变原始列表的内容复制了一下
    82         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, 
    83                                         bestFeat, value),subLabels)
    84     return myTree
    85     
    86 def main():
    87     data,label = createDataSet()
    88     t1 = time.clock()
    89     myTree = createTree(data,label)
    90     t2 = time.clock()
    91     print myTree
    92     print 'execute for ',t2-t1
    93 if __name__=='__main__':
    94     main()

     

    机器学习笔记索引



posted @ 2014-11-15 15:45  mrbean  阅读(30687)  评论(1编辑  收藏  举报