随笔 - 402  文章 - 1 评论 - 20 阅读 - 113万
< 2025年3月 >
23 24 25 26 27 28 1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30 31 1 2 3 4 5

1、简单概念描述

       决策树的类型有很多,有CART、ID3和C4.5等,其中CART是基于基尼不纯度(Gini)的,这里不做详解,而ID3和C4.5都是基于信息熵的,它们两个得到的结果都是一样的,本次定义主要针对ID3算法。下面我们介绍信息熵的定义。

      p(ai):事件ai发生的概率

  I(ai)=-log2(p(ai)):表示为事件ai的不确定程度,称为ai的自信息量

  H=sum(p(ai)*I(ai)):称为信源S的平均信息量—信息熵

  Gain = BaseEntropy – newEntropy:信息增益

    决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一棵熵值下降最快的树,到叶子节点处的熵值为零,此时每个叶节点中的实例都属于同一类。ID3的原理是基于信息熵增益Gain达到最大,设原始问题的标签有正例和负例,p和n表示其相应的个数。则原始问题的信息熵为

    其中N为该特征所取值的个数,比如{rain,sunny},则N即为2

  ID3易出现的问题:如果是取值更多的属性,更容易使得数据更“纯”(尤其是连续型数值),其信息增益更大,决策树会首先挑选这个属性作为树的顶点。结果训练出来的形状是一棵庞大且深度很浅的树,这样的划分是极为不合理的。 此时可以采用C4.5来解决,C4.5的思想是最大化Gain除以下面这个公式即得到信息增益率:

  其中底为2

2、决策树的优缺点

优点:计算复杂度不高,输出结果易于理解,对中间值缺失不敏感,可以处理不相关特征数据

缺点:可能产生过度匹配问题

适用数据类型:数值型和标称型

3、python代码的实现

以下的代码根据这些数据理解

数据1中包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。我们可以将这些动物分成两类:鱼类和非鱼类。

  不浮出水面是否可以生存 是否有脚蹼 属于鱼类
1
2
3
4
5

 

  特征[0](no surfacing) 特征[1](flippers) 特征[-1]fish
dataSet[0] 1 1 yes
dataSet[1] 1 1 yes
dataSet[2] 0 1 no
dataSet[3] 0 1 no
dataSet[4] 0 1 no

创建名为trees.py的文件,下面代码内容都在此文件中。

(1)计算信息熵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# -*- coding: utf-8 -*-<br>#计算给定数据集的香农熵
def calcShannonEnt(dataSet): 
  numEntries=len(dataSet)  #数据实例总数
  labelCounts={}  #对类别数量创建了一个数据字典,键值是最后一列的数值
  for featVec in dataSet:   #featVec表示特征集
      currentLabel=featVec[-1]    # currentLabel表示当前键值,featVec[-1]表示数据集中的最后一列
      #如果当前键值不存在,扩展字典将当前键值加入字典,设置当前键值表示的类别数量为0
      if currentLabel not in labelCounts.keys():
          labelCounts[currentLabel]=0
      #如果当前键值存在,则类别数量累加
      labelCounts[currentLabel]+=1
  shannonEnt=0.0
  for key in labelCounts:
      prob=float(labelCounts[key])/numEntries #每个键值都记录了当前类别出现的次数
      shannonEnt -=prob*log(prob,2)
  return shannonEnt 

(2)创建数据集

1
2
3
4
5
#创建数据集
def createDataSet():
    dataSet=[[1,1,'yes'],[1,1,'yes'],[0,1,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flippers']
    return dataSet,labels

在python命令提示符下输入下列命令:

复制代码
1 >>> import trees
2 >>> reload(trees)
3 <module 'trees' from 'E:\python excise\trees.pyc'>
4 >>> myDat,labels=trees.createDataSet()
5 >>> myDat
6 [[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']]
7 >>> trees.calcShannonEnt(myDat)
8 0.9709505944546686
9 >>> 
复制代码

熵越高,则混合的数据越多,在数据集中添加更多的分类,观察熵是如何变化的,这里增加第三个名为maybe的分类,测试熵的变化:

>>> myDat[0][-1]='maybe'  
>>> myDat
[[1, 1, 'maybe'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(myDat)
1.3709505944546687

得到熵后,我们可以按照获取最大信息增益的方法划分数据集

(3)划分数据集

 我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式

1
2
3
4
5
6
7
8
9
10
#按照给定特征划分数据集
#dataSet:待划分的数据集,axis:划分数据集的特征,value:需要返回的特征的值
def splitDataSet(dataSet,axis,value):
    retDataSet=[]   #为了不修改原始数据dataSet,创建一个新的列表对象
    for featVec in dataSet:
        if featVec[axis]==value:    
            reducedFeatVec=featVec[:axis]   #获取从第0列到特征列的数据
            reducedFeatVec.extend(featVec[axis+1:])  #获取从特征列之后的数据
            retDataSet.append(reducedFeatVec) #目前reducedFeatVec表示除了特征列的数据
    return retDataSet
复制代码
1 >>> reload(trees)
2 <module 'trees' from 'E:\python excise\trees.pyc'>
3 >>> myDat,labels=trees.createDataSet()
4 >>> myDat
5 [[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']]
6 >>> trees.splitDataSet(myDat,0,1)
7 [[1, 'yes'], [1, 'yes']]
8 >>> trees.splitDataSet(myDat,0,0)
9 [[1, 'no'], [1, 'no'], [1, 'no']]
复制代码

(4)选择最好的特征进行划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
  numFeatures=len(dataSet[0])-1        #减去类别那一列
  baseEntropy=calcShannonEnt(dataSet)   #计算整个数据集的原始香农熵
  bestInfoGain=0.0;bestFeature=-1  #现在最好的特征是数据集中的最后一列<br>  #i=0,新熵,增益<br>  #i=1,新熵,增益
  for i in range(numFeatures):    #循环遍历数据集中的所有特征
    featList=[example[i] for example in dataSet]  #获取第i个特征所有可能的取值,特征0一个列表,特征1一个列表...
    uniqueVals=set(featList)  #集合数据类型(set)与列表类型相似,不同之处仅在于集合类型中每个值互不相同
    newEntropy=0.0
    for value in uniqueVals:
      subDataSet=splitDataSet(dataSet,i,value)  #划分后的数据集
      prob=len(subDataSet)/float(len(dataSet))
      newEntropy+=prob*calcShannonEnt(subDataSet) #求划分完的数据集的熵
    infoGain=baseEntropy-newEntropy
    if(infoGain>bestInfoGain):
      bestInfoGain=infoGain
      bestFeature=i          
  return bestFeature

注意:这里数据集需要满足以下两个办法:

<1>所有的列元素都必须具有相同的数据长度

<2>数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。

1 >>> reload(trees)
2 <module 'trees' from 'E:\python excise\trees.pyc'>
3 >>> myDat,labels=trees.createDataSet()
4 >>> trees.chooseBestFeatureToSplit(myDat)
5 0

(5)创建树的代码

Python用字典类型来存储树的结构,返回的结果是myTree-字典

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#创建树的函数代码
def createTree(dataSet,labels):
  classList=[example[-1] for example in dataSet]
  if classList.count(classList[0])==len(classList):  #类别完全相同规则停止继续划分
    return classList[0]
  if len(dataSet[0])==1: #确认至少有数据集<br>    return majorityCnt(classList)
  bestFeat=chooseBestFeatureToSplit(dataSet)
  bestFeatLabel=labels[bestFeat]
  myTree={bestFeatLabel:{}}
  del(labels[bestFeat])  #得到列表包含的所有属性
  featValues=[example[bestFeat] for example in dataSet]
  uniqueVals=set(featValues)
  for value in uniqueVals:
    subLabels=labels[:]
    myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
  return myTree 

其中递归结束当且仅当该类别中标签完全相同或者遍历所有的特征此时返回次数最多的

1 >>> reload(trees)
2 <module 'trees' from 'E:\python excise\trees.pyc'>
3 >>> myDat,labels=trees.createDataSet()
4 >>> myTree=trees.createTree(myDat,labels)
5 >>> myTree
6 {'no surfacing': {0: 'no', 1: 'yes'}}

其中当所有的特征都用完时,采用多数表决的方法来决定该叶子节点的分类,即该叶节点中属于某一类最多的样本数,那么我们就说该叶节点属于那一类。即为如果数据集已经处理了所有的属性,但是类标签依然不是唯一的,此时我们要决定如何定义该叶子节点,在这种情况下,我们通常采用多数表决的方法来决定该叶子节点的分类。代码如下:

1
2
3
4
5
6
7
def majorityCnt(classList):
  classCount={}
  for vote in classList:
    if vote not in classCount.keys():classCount[vote]=0
    classCount[vote]+=1
 sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
  return sortedClassCount[0][0]

(6)使用决策树执行分类

1
2
3
4
5
6
7
8
9
10
11
#测试算法:使用决策树执行分类
def classify(inputTree,featLabels,testVec):
  firstStr=inputTree.keys()[0]
  secondDict=inputTree[firstStr]
  featIndex=featLabels.index(firstStr)
  for key in secondDict.keys():
    if testVec[featIndex]==key:
      if type(secondDict[key]).__name__=='dict':
        classLabel=classify(secondDict[key],featLabels,testVec)
      else:classLabel=secondDict[key]
  return classLabel
1 >>> import trees
2 >>> myDat,labels=trees.createDataSet()
3 >>> labels
4 ['no surfacing', 'flippers']
5 >>> trees.classify(myTree,labels,[1,0])
6 'no'
7 >>> trees.classify(myTree,labels,[1,1])
8 'yes'

注意递归的思想很重要。

(7)决策树的存储

构造决策树是一个很耗时的任务。为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用python模块pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。

1
2
3
4
5
6
7
8
9
10
#使用算法:决策树的存储
def storeTree(inputTree,filename):
  import pickle
  fw=open(filename,'w')
  pickle.dump(inputTree,fw)
  fw.close()
def grabTree(filename):
  import pickle
  fr=open(filename)
  return pickle.load(fr)
1 >>> reload(trees)
2 >>><module 'tree' from 'trees.py'>
3 >>> trees.storeTree(myTree,'classifierStorage.txt')
4 >>> trees.grabTree('classifierStorage.txt')
5 {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

classifierStorage.txt如下:

补充:

用matplotlib注解上述形成的决策树

Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。

创建名为treePlotter.py文件,下面代码都在此文件中

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/python
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
from numpy import *
import operator
#定义文本框和箭头格式
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
#绘制箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
def createPlot():
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1=plt.subplot(111,frameon=False)
    plotNode(U'决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode(U'叶节点',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()
#获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs +=1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=myTree.keys()[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth
 
def retrieveTree(i):
    listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\
                 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]
#在父节点间填充文本信息      
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
#计算宽和高
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=myTree.keys()[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)   #计算父节点和子节点的中间位置
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
        plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

 

其中index方法为查找当前列表中第一个匹配firstStr的元素 返回的为索引。

 

posted on   chamie  阅读(621)  评论(0编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示