决策树之python实现ID3算法(例子)
引言
决策树从本质上是从训练数据集上训练处一组分类规则,完全依据训练数据,所得规则容易发生过拟合,这也是决策树的缺点,不过可以通过决策树的剪枝,来提高决策树的泛化能力。
由此,决策树的创建可包括三部分:特征选择、决策树的生成以及决策树的剪枝;决策树的应用包括:分类、回归以及特征选择。
决策树最经典的算法包括:ID3、C4.5以及CART算法,ID3与C4.5算法相似,C4.5在特征选择时选用的信息准则是信息增益比,而ID3用的是信息增益;因为信息增益偏向于选择具有较多可能取值的特征
基于信息论的特征选择
注意:熵表示随机变量的不确定性,熵值越大表示随机变量含有的信息越少,变量的不确定性越大。
1) 香侬定义一个数据的信息可按下式计算 (此处是以2为底的对数):
2)熵表示一个数据集合信息的期望,可按下式计算:(该式不理解,可想象下,求变量期望的公式,p(xi) 为变量 xi以及信息 l(xi) 的概率,概率乘以变量(信息)即为变量(信息)的期望)
3)特征 AA 对数据集 DD 的信息增益为:
上式中,设训练数据集为D,其样本容量为|D|,即样本个数,设共有K个类Ck,k=1,2,...,K Ck,k=1,2,...,K , |Ck| 为Ck的样本个数,
根据特征A 的取值将 D 划分为n个子集D1,D2,...,Dn, |Di|为 Di的样本数,Dik=Di⋂Ck ,|Dik|为 Dik的样本个数.
如下表和图所示:
feature1(A) | feature1 | feature3 | labels |
---|---|---|---|
a1 | b1 | c1 | y |
a1 | b2 | c2 | n |
a1 | b1 | c2 | n |
a1 | b1 | c2 | n |
a2 | b1 | c1 | y |
a2 | b2 | c2 | y |
a2 | b1 | c1 | n |
python实现:
1 def calcShannonEnt(dataset):#计算熵 2 numSamples = len(dataset) 3 labelCounts = {} 4 for allFeatureVector in dataset: 5 currentLabel = allFeatureVector[-1] 6 if currentLabel not in labelCounts.keys(): 7 labelCounts[currentLabel] = 0 8 labelCounts[currentLabel] += 1 9 entropy = 0.0 10 for key in labelCounts: 11 property = float(labelCounts[key])/numSamples 12 entropy -= property * log(property,2) 13 return entropy 14 def BestFeatToGetSubdataset(dataset): 15 #下边这句实现:除去最后一列类别标签列剩余的列数即为特征个数 16 numFeature = len(dataset[0]) - 1 17 baseEntropy = calcShannonEnt(dataset) 18 bestInfoGain = 0.0; bestFeature = -1 19 for i in range(numFeature):#i表示该函数传入的数据集中每个特征 20 # 下边这句实现抽取特征i在数据集中的所有取值 21 feat_i_values = [example[i] for example in dataset] 22 uniqueValues = set(feat_i_values) 23 feat_i_entropy = 0.0 24 for value in uniqueValues: 25 subDataset = getSubDataset(dataset,i,value) 26 #下边这句计算pi,实现计算信息增益最大的特征 27 prob_i = len(subDataset)/float(len(dataset)) 28 feat_i_entropy += prob_i * calcShannonEnt(subDataset) 29 infoGain_i = baseEntropy - feat_i_entropy 30 if (infoGain_i > bestInfoGain): 31 bestInfoGain = infoGain_i 32 bestFeature = i 33 return bestFeature
决策树生成
决策树生成可用下边的流程图表示:
ID3算法python实现代码:
1 # -*- coding: utf-8 -*- 2 from math import log 3 import operator 4 import pickle 5 ''' 6 输入:原始数据集、子数据集(最后一列为类别标签,其他为特征列) 7 功能:计算原始数据集、子数据集(某一特征取值下对应的数据集)的香农熵 8 输出:float型数值(数据集的熵值) 9 ''' 10 def calcShannonEnt(dataset): 11 numSamples = len(dataset) 12 labelCounts = {} 13 for allFeatureVector in dataset: 14 currentLabel = allFeatureVector[-1] 15 if currentLabel not in labelCounts.keys(): 16 labelCounts[currentLabel] = 0 17 labelCounts[currentLabel] += 1 18 entropy = 0.0 19 for key in labelCounts: 20 property = float(labelCounts[key])/numSamples 21 entropy -= property * log(property,2) 22 return entropy 23 24 ''' 25 输入:无 26 功能:封装原始数据集 27 输出:数据集、特征标签 28 ''' 29 def creatDataSet(): 30 dataset = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']] 31 labels = ['no surfacing','flippers'] 32 return dataset,labels 33 34 ''' 35 输入:数据集、数据集中的某一特征所在列的索引、该特征某一可能取值(例如,(原始数据集、0,1 )) 36 功能:取出在该特征取值下的子数据集(子集不包含该特征) 37 输出:子数据集 38 ''' 39 def getSubDataset(dataset,colIndex,value): 40 subDataset = [] #用于存储子数据集 41 for rowVector in dataset: 42 if rowVector[colIndex] == value: 43 #下边两句实现抽取除第colIndex列特征的其他特征取值 44 subRowVector = rowVector[:colIndex] 45 subRowVector.extend(rowVector[colIndex+1:]) 46 #将抽取的特征行添加到特征子数据集中 47 subDataset.append(subRowVector) 48 return subDataset 49 50 ''' 51 输入:数据集 52 功能:选择最优的特征,以便得到最优的子数据集(可简单的理解为特征在决策树中的先后顺序) 53 输出:最优特征在数据集中的列索引 54 ''' 55 def BestFeatToGetSubdataset(dataset): 56 #下边这句实现:除去最后一列类别标签列剩余的列数即为特征个数 57 numFeature = len(dataset[0]) - 1 58 baseEntropy = calcShannonEnt(dataset) 59 bestInfoGain = 0.0; bestFeature = -1 60 for i in range(numFeature):#i表示该函数传入的数据集中每个特征 61 # 下边这句实现抽取特征i在数据集中的所有取值 62 feat_i_values = [example[i] for example in dataset] 63 uniqueValues = set(feat_i_values) 64 feat_i_entropy = 0.0 65 for value in uniqueValues: 66 subDataset = getSubDataset(dataset,i,value) 67 #下边这句计算pi 68 prob_i = len(subDataset)/float(len(dataset)) 69 feat_i_entropy += prob_i * calcShannonEnt(subDataset) 70 infoGain_i = baseEntropy - feat_i_entropy 71 if (infoGain_i > bestInfoGain): 72 bestInfoGain = infoGain_i 73 bestFeature = i 74 return bestFeature 75 76 ''' 77 输入:子数据集的类别标签列 78 功能:找出该数据集个数最多的类别 79 输出:子数据集中个数最多的类别标签 80 ''' 81 def mostClass(ClassList): 82 classCount = {} 83 for class_i in ClassList: 84 if class_i not in classCount.keys(): 85 classCount[class_i] = 0 86 classCount[class_i] += 1 87 sortedClassCount = sorted(classCount.iteritems(), 88 key=operator.itemgetter(1),reverse = True) 89 return sortedClassCount[0][0] 90 91 ''' 92 输入:数据集,特征标签 93 功能:创建决策树(直观的理解就是利用上述函数创建一个树形结构) 94 输出:决策树(用嵌套的字典表示) 95 ''' 96 def creatTree(dataset,labels): 97 classList = [example[-1] for example in dataset] 98 #判断传入的dataset中是否只有一种类别,是,返回该类别 99 if classList.count(classList[0]) == len(classList): 100 return classList[0] 101 #判断是否遍历完所有的特征,是,返回个数最多的类别 102 if len(dataset[0]) == 1: 103 return mostClass(classList) 104 #找出最好的特征划分数据集 105 bestFeat = BestFeatToGetSubdataset(dataset) 106 #找出最好特征对应的标签 107 bestFeatLabel = labels[bestFeat] 108 #搭建树结构 109 myTree = {bestFeatLabel:{}} 110 del (labels[bestFeat]) 111 #抽取最好特征的可能取值集合 112 bestFeatValues = [example[bestFeat] for example in dataset] 113 uniqueBestFeatValues = set(bestFeatValues) 114 for value in uniqueBestFeatValues: 115 #取出在该最好特征的value取值下的子数据集和子标签列表 116 subDataset = getSubDataset(dataset,bestFeat,value) 117 subLabels = labels[:] 118 #递归创建子树 119 myTree[bestFeatLabel][value] = creatTree(subDataset,subLabels) 120 return myTree 121 122 ''' 123 输入:测试特征数据 124 功能:调用训练决策树对测试数据打上类别标签 125 输出:测试特征数据所属类别 126 ''' 127 def classify(inputTree,featlabels,testFeatValue): 128 firstStr = inputTree.keys()[0] 129 secondDict = inputTree[firstStr] 130 featIndex = featlabels.index(firstStr) 131 for firstStr_value in secondDict.keys(): 132 if testFeatValue[featIndex] == firstStr_value: 133 if type(secondDict[firstStr_value]).__name__ == 'dict': 134 classLabel = classify(secondDict[firstStr_value],featlabels,testFeatValue) 135 else: classLabel = secondDict[firstStr_value] 136 return classLabel 137 138 139 ''' 140 输入:训练树,存储的文件名 141 功能:训练树的存储 142 输出: 143 ''' 144 def storeTree(trainTree,filename): 145 146 fw = open(filename,'w') 147 pickle.dump(trainTree,fw) 148 fw.close() 149 def grabTree(filename): 150 151 fr = open(filename) 152 return pickle.load(fr) 153 154 155 if __name__ == '__main__': 156 dataset,labels = creatDataSet() 157 storelabels = labels[:]#复制label 158 trainTree = creatTree(dataset,labels) 159 classlabel = classify(trainTree,storelabels,[0,1]) 160 print classlabel
本文来自于:
谢谢博主