决策树ID3算法python实现 -- 《机器学习实战》

 1 from math import log
 2 import numpy as np
 3 import matplotlib.pyplot as plt
 4 import operator
 5 
 6 #计算给定数据集的香农熵
 7 def calcShannonEnt(dataSet):
 8     numEntries = len(dataSet)
 9     labelCounts = {}
10     for featVec in dataSet:                         #|
11         currentLabel = featVec[-1]                  #|
12         if currentLabel not in labelCounts.keys():  #|获取标签类别取值空间(key)及出现的次数(value)
13             labelCounts[currentLabel] = 0           #|
14         labelCounts[currentLabel] += 1              #|
15     shannonEnt = 0.0
16     for key in labelCounts:                         #|
17         prob = float(labelCounts[key])/numEntries   #|计算香农熵
18         shannonEnt -= prob * log(prob, 2)           #|
19     return shannonEnt
20 
21 #创建数据集
22 def createDataSet():
23     dataSet = [[1,1,'yes'],
24                [1,1,'yes'],
25                [1,0,'no'],
26                [0,1,'no'],
27                [0,1,'no']]
28     labels = ['no surfacing', 'flippers']
29     return dataSet, labels
30 
31 #按照给定特征划分数据集
32 def splitDataSet(dataSet, axis, value):
33     retDataSet = []
34     for featVec in dataSet:                         #|
35         if featVec[axis] == value:                  #|
36             reducedFeatVec = featVec[:axis]         #|抽取出符合特征的数据
37             reducedFeatVec.extend(featVec[axis+1:]) #|
38             retDataSet.append(reducedFeatVec)       #|
39     return retDataSet
40 
41 #选择最好的数据集划分方式
42 def chooseBestFeatureToSplit(dataSet):
43     numFeatures = len(dataSet[0]) - 1
44     basicEntropy = calcShannonEnt(dataSet)
45     bestInfoGain = 0.0; bestFeature = -1
46     for i in range(numFeatures):        #计算每一个特征的熵增益
47         featlist = [example[i] for example in dataSet]
48         uniqueVals = set(featlist)
49         newEntropy = 0.0
50         for value in uniqueVals:        #计算每一个特征的不同取值的熵增益
51             subDataSet = splitDataSet(dataSet, i, value)
52             prob = len(subDataSet)/float(len(dataSet))
53             newEntropy += prob * calcShannonEnt(subDataSet) #不同取值的熵增加起来就是整个特征的熵增益
54         infoGain = basicEntropy - newEntropy
55         if (infoGain > bestInfoGain):   #选择最高的熵增益作为划分方式
56             bestInfoGain = infoGain
57             bestFeature = i
58     return bestFeature
59 #挑选出现次数最多的类别
60 def majorityCnt(classList):
61     classCount={}
62     for vote in classList:
63         if vote not in classCount.keys():
64             classCount[vote] = 0
65         classCount[vote] += 1
66     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)
67     return sortedClassCount[0][0]
68 
69 def createTree(dataSet, labels):
70     classList = [example[-1] for example in dataSet]
71     if classList.count(classList[0]) == len(classList): #停止条件一:判断所有类别标签是否相同,完全相同则停止继续划分
72         return classList[0]
73     if len(dataSet[0]) == 1:    #停止条件二:遍历完所有特征时返回出现次数最多的
74         return majorityCnt(classList)
75     bestFeat = chooseBestFeatureToSplit(dataSet)    #得到列表包含的所有属性值
76     bestFeatLabel = labels[bestFeat]
77     myTree = {bestFeatLabel:{}}
78     del(labels[bestFeat])
79     featValues = [example[bestFeat] for example in dataSet]
80     uniqueVals = set(featValues)
81     for value in uniqueVals:
82         subLabels = labels[:]
83         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
84     return myTree
85 
86 # Simple unit test of func: createDataSet()
87 myDat, labels = createDataSet()
88 print (myDat)
89 #print (labels)
90 # Simple unit test of func: splitDataSet()
91 splitData = splitDataSet(myDat,0,1)
92 print (splitData)
93 # Simple unit test of func: chooseBestFeatureToSplit()
94 chooseResult = chooseBestFeatureToSplit(myDat)
95 print (chooseResult)
96 # Simple unit test of func: createTree(
97 myDat, labels = createDataSet()
98 myTree = createTree(myDat, labels)
99 print(myTree)

Output:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

 

Reference:

《机器学习实战》

posted @ 2017-11-13 10:13  刘川枫  阅读(431)  评论(0编辑  收藏  举报