day-8 python自带库实现ID3决策树算法

 

  前一天,我们基于sklearn科学库实现了ID3的决策树程序,本文将基于python自带库实现ID3决策树算法。

 一、代码涉及基本知识

  1、 为了绘图方便,引入了一个第三方treePlotter模块进行图形绘制。该模块使用方法简单,调用模块createPlot接口,传入一个树型结构对象,即可绘制出相应图像。

  2、  在python中,如何定义一个树型结构对象

    可以使用了python自带的字典数据类型来定义一个树型对象。例如下面代码,我们定义一个根节点和两个左右子节点:

    rootNode = {'rootNode': {}}
    leftNode = {'leftNode': {'yes':'yes'}}
    rightNode = {'rightNode': {'no':'no'}}
    rootNode['rootNode']['left'] = leftNode
    rootNode['rootNode']['right'] = rightNode
    treePlotter.createPlot(rootNode)

    通过调用treePlotter模块,绘制出如下树的图像

    

  2、  递归调用

    为了求每个节点的各个子节点,要用到递归的方法来实现,基本思想和二叉树的遍历方法一致,后面我们还会用Python实现一个二叉树源码,此处不再进行介绍。

  3、  此外,还需要对python常用的数据类型及其操作比较了解,例如字典、列表、集合等

二、程序主要流程

 

 

 

三、测试数据集

age

income

student

credit_rating

class_buys_computer

youth

high

no

fair

no

youth

high

no

excellent

no

middle_aged

high

no

fair

yes

senior

medium

no

fair

yes

senior

low

yes

fair

yes

senior

low

yes

excellent

no

middle_aged

low

yes

excellent

yes

youth

medium

no

fair

no

youth

low

yes

fair

yes

senior

medium

yes

fair

yes

youth

medium

yes

excellent

yes

middle_aged

medium

no

excellent

yes

middle_aged

high

yes

fair

yes

senior

medium

no

excellent

no

四、程序代码

         1、计算测试集熵及信息增益        

# 求最优的根节点
def chooseBestFeatureToSplit(dataset,headerList):
    # 定义一个初始值
    bestInfoGainRate = 0.0
    bestFeature = 0
    # 求特征列项的数量
    numFeatures = len(dataset[0]) -1
    # 获取整个测试数据集的熵
    baseShnnonEnt = calcShannonEnt(dataset)
    print("total's shannonEnt = %f" % (baseShnnonEnt))
    # 遍历每一个特征列,求取信息增益
    for i in range(numFeatures):
        # 获取某一列所有特征值
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        newEntropy = 0.0
        # 求得某一列某一个特征值的概率和熵
        newShannonEnt = 0.0
        for value in uniqueVals:
            # 计算熵
            subDataset = splitDataSet(dataset,i,value)
            newEntropy = calcShannonEnt(subDataset)
            # 计算某一列某一个特征值的概率
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability*calcShannonEnt(subDataset)
        infoGainRate = baseShnnonEnt - newShannonEnt
        print("%s's infoGainRate = %f - %f = %f"%(headerList[i],baseShnnonEnt,newShannonEnt,infoGainRate))
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature

  该结果和前一天计算结果一致,age特征对应信息增益最大,因此设为根节点:

        

         2、程序源码

         treePlotter.py        

import matplotlib.pyplot as plt

# 定义决策树决策结果属性
descisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # nodeTxt为要显示的文本,centerNode为文本中心点, nodeType为箭头所在的点, parentPt为指向文本的点
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                             xytext=centerPt, textcoords='axes fraction',
                              va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]     # 这个是改的地方,原来myTree.keys()返回的是dict_keys类,不是列表,运行会报错。有好几个地方这样
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = {'xticks': None, 'yticks': None}
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotTree.totalW = float(getNumLeafs(inTree))     # 全局变量宽度 = 叶子数目
    plotTree.totalD = float(getTreeDepth(inTree))     # 全局变量高度 = 深度
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # cntrPt文本中心点, parentPt指向文本中心的点
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, descisionNode)
    seconDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in seconDict.keys():
        if type(seconDict[key]).__name__ == 'dict':
            plotTree(seconDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va='center', ha='center', rotation=30)

         decision_tree_ID3.py

# 导入库
import csv
import math
import operator
import treePlotter


# 导入数据集
def readDataset(file_path,file_mode):
    allElectronicsData = open(file_path, file_mode)
    reader = csv.reader(allElectronicsData)
    # 读取特征名称
    headers = next(reader)
    # 读取测试数据集
    dataset = []
    for row in reader:
        dataset.append(row)
    return headers,dataset

# 求某个数据集的熵
def calcShannonEnt(dataset):
    shannonEnt = 0.0
    labelList = {}
    for vec_now in dataset:
        labelValue = vec_now[-1]
        if vec_now[-1] not in labelList.keys():
            labelList[labelValue] = 0
        labelList[labelValue] += 1
    for labelKey in labelList:
        probability = float(labelList[labelKey] / len(dataset))
        shannonEnt -= probability*math.log(probability,2)
    return shannonEnt

# 根据给定的列特征值,分理出给定的特征量
def splitDataSet(dataset,feature_seq,value):
    new_dataset = []
    for vec_row in dataset:
        feature_Value = vec_row[feature_seq]
        if feature_Value == value:
            temp_vec = []
            temp_vec = vec_row[:feature_seq]
            temp_vec.extend(vec_row[feature_seq+1:])
            new_dataset.append(temp_vec)
    return new_dataset

# 求最优的根节点
def chooseBestFeatureToSplit(dataset,headerList):
    # 定义一个初始值
    bestInfoGainRate = 0.0
    bestFeature = 0
    # 求特征列项的数量
    numFeatures = len(dataset[0]) -1
    # 获取整个测试数据集的熵
    baseShnnonEnt = calcShannonEnt(dataset)
    #print("total's shannonEnt = %f" % (baseShnnonEnt))
    # 遍历每一个特征列,求取信息增益
    for i in range(numFeatures):
        # 获取某一列所有特征值
        featureList = [example[i] for example in dataset]
        uniqueVals = set(featureList)
        newEntropy = 0.0
        # 求得某一列某一个特征值的概率和熵
        newShannonEnt = 0.0
        for value in uniqueVals:
            # 计算熵
            subDataset = splitDataSet(dataset,i,value)
            newEntropy = calcShannonEnt(subDataset)
            # 计算某一列某一个特征值的概率
            newProbability = len(subDataset) / float(len(dataset))
            newShannonEnt += newProbability*calcShannonEnt(subDataset)
        infoGainRate = baseShnnonEnt - newShannonEnt
        #print("%s's infoGainRate = %f - %f = %f"%(headerList[i],baseShnnonEnt,newShannonEnt,infoGainRate))
        if infoGainRate > bestInfoGainRate:
            bestInfoGainRate = infoGainRate
            bestFeature = i
    return bestFeature

# 标签判定,通过少数服从多数原则
def majorityCnt(classList):
    classcount = {}
    for cl in classList:
        if cl not in classcount.keys():
            classcount[cl] = 0
        classcount[cl] += 1
    sortedClassCount = sorted(classcount.items(),key = operator.itemgetter(1),reverse= True)
    return sortedClassCount[0][0]

# 创建一个决策树
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # 1 所有特征值都是相同的时候直接返回
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 2 遍历完所有特征值,投票原则,返回出现次数最多的标签
    if len(dataSet[0])  == 1:
        return majorityCnt(classList)
    # 3 如果不满足上面两者,求最优特征
    bestFeature = chooseBestFeatureToSplit(dataSet,labels)
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel: {}}
    del (labels[bestFeature])
    featurValues = [example[bestFeature] for example in dataSet]
    uniqueVals = set(featurValues)
    # 使用递归的方法,获得整个树
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

def classifyAll(inputTree, featLabels, testDataSet):
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

def main():
    # 读取数据集
    labels, dataSet = readDataset(file_path=r'D:\test.csv', file_mode='r')
    labels_tmp = labels[:] # 拷贝,createTree会改变labels
    desicionTree = createTree(dataSet, labels_tmp)
    storeTree(desicionTree, 'classifierStorage.txt')
    desicionTree = grabTree('classifierStorage.txt')
    treePlotter.createPlot(desicionTree)
    testSet = [['youth', 'high', 'no', 'fair', 'no']]
    print('classifyResult:\n', classifyAll(desicionTree, labels, testSet))

if __name__ == '__main__':
    main()

五、测试结果及结论

 

  我们从上面求解信息增益的公式中,其实可以看出,信息增益准则其实是对可取值数目较多的属性有所偏好!
  现在假如我们把数据集中的“编号”也作为一个候选划分属性。我们可以算出“编号”的信息增益是0.998
  因为每一个样本的编号都是不同的(由于编号独特唯一,条件熵为0了,每一个结点中只有一类,纯度非常高啊),也就是说,来了一个预测样本,你只要告诉我编号,其它特征就没有用了,这样生成的决策树显然不具有泛化能力。

 

  参考链接:

  http://www.cnblogs.com/wsine/p/5180310.html

  https://zhuanlan.zhihu.com/p/26760551

 

posted @ 2018-04-05 05:03  派森蛙  阅读(2237)  评论(0编辑  收藏  举报