决策树---ID3算法(介绍及Python实现)

决策树---ID3算法

 

决策树:

以天气数据库的训练数据为例。

 

Outlook

Temperature

Humidity

Windy

PlayGolf?

sunny

85

85

FALSE

no

sunny

80

90

TRUE

no

overcast

83

86

FALSE

yes

rainy

70

96

FALSE

yes

rainy

68

80

FALSE

yes

rainy

65

70

TRUE

no

overcast

64

65

TRUE

yes

sunny

72

95

FALSE

no

sunny

69

70

FALSE

yes

rainy

75

80

FALSE

yes

sunny

75

70

TRUE

yes

overcast

72

90

TRUE

yes

overcast

81

75

FALSE

yes

rainy

71

91

TRUE

no

这个例子是根据报告天气条件的记录来决定是否外出打高尔夫球。

 

作为分类器,决策树是一棵有向无环树。

由根节点、叶节点、内部点、分割属性、分割判断规则构成

 

生成阶段:决策树的构建和决策树的修剪。

根据分割方法的不同:有基于信息论(Information Theory的方法和基于最小GINI指数(lowest GINI index的方法。对应前者的常见方法有ID3、C4.5,后者的有CART

 ID3 算法

       ID3的基本概念是:

1)  决策树中的每一个非叶子节点对应着一个特征属性,树枝代表这个属性的值。一个叶节点代表从树根到叶节点之间的路径所对应的记录所属的类别属性值。这就是决策树的定义。

2)  在决策树中,每一个非叶子节点都将与属性中具有最大信息量的特征属性相关联。

3)  熵通常是用于测量一个非叶子节点的信息量大小的名词。

 

热力学中表征物质状态的参量之一,用符号S表示,其物理意义是体系混乱程度的度量。热力学第二定律(second law of thermodynamics),热力学基本定律之一,又称“熵增定律”,表明在自然过程中,一个孤立系统的总混乱度(即“熵”)不会减小。

在信息论中,变量的不确定性越大,熵也就越大,把它搞清楚所需要的信息量也就越大。信息熵是信息论中用于度量信息量的一个概念。一个系统越是有序,信息熵就越低;反之,一个系统越是混乱,信息熵就越高。所以,信息熵也可以说是系统有序化程度的一个度量。

 信息增益的计算

定义1:若存在个相同概率的消息,则每个消息的概率是,一个消息传递的信息量为。若有16个事件,则,需要4个比特来代表一个消息。

定义2若给定概率分布则由该分布传递的信息量称为的熵,即

 

例:若是,则是1;若是,则是0.92;若

是,则是0(注意概率分布越均匀,其信息量越大)

定义3若一个记录的集合根据类别属性的值被分为相互独立的类,则识别的一个元素所属哪个类别所需要的信息量是,其中是的概率分布,即

 

 

仍以天气数据库的数据为例。我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据,判断一下会不会去打球。在没有给定任何天气信息时,根据历史数据,我们知道一天中打球的概率是9/14,不打的概率是5/14。此时的熵为:

 

定义4:若我们根据某一特征属性将分成集合,则确定中的一个元素类的信息量可通过确定的加权平均值来得到,即的加权平均值为:

 

 

 

Outlook

temperature

humidity

windy

play

 

yes

no

 

 

 

yes

no

yes

no

sunny

2

3

False

6

2

9

5

overcast

4

0

True

3

3

 

 

rainy

3

2

 

 

 

 

 

 

针对属性Outlook,我们来计算

定义5:将信息增益定义为:

 

即增益的定义是两个信息量之间的差值,其中一个信息量是需确定的一个元素的信息量,另一个信息量是在已得到的属性的值后确定的一个元素的信息量,即信息增益与属性相关。

针对属性Outlook的增益值:

 

若用属性windy替换outlook,可以得到,。即outlook比windy取得的信息量大。

ID3算法的Python实现

import math 
import operator

def calcShannonEnt(dataset):
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] +=1
        
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*math.log(prob, 2)
    return shannonEnt
    
def CreateDataSet():
    dataset = [[1, 1, 'yes' ], 
               [1, 1, 'yes' ], 
               [1, 0, 'no'], 
               [0, 1, 'no'], 
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    numberFeatures = len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0;
    bestFeature = -1;
    for i in range(numberFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy =0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    classCount ={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]=1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) 
    return sortedClassCount[0][0]
 

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

        
        
myDat,labels = CreateDataSet()
createTree(myDat,labels)

运行结果如下:

 

posted @ 2017-06-04 12:16  fcyh  阅读(17949)  评论(0编辑  收藏  举报