(二)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”——CART决策树

CART决策树

(一)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”

参照上一篇ID3算法实现的决策树(点击上面链接直达),进一步实现CART决策树。

其实只需要改动很小的一部分就可以了,把原先计算信息熵和信息增益的部分换做计算基尼指数,选择最优属性的时候,选择最小的基尼指数即可。

#导入模块
import pandas as pd
import numpy as np
from collections import Counter

#数据获取与处理
def getData(filePath):
    data = pd.read_excel(filePath)
    return data

def dataDeal(data):
    dataList = np.array(data).tolist()
    dataSet = [element[1:] for element in dataList]
    return dataSet

#获取属性名称
def getLabels(data):
    labels = list(data.columns)[1:-1]
    return labels

#获取类别标记
def targetClass(dataSet):
    classification = set([element[-1] for element in dataSet])
    return classification
    
#将分支结点标记为叶结点,选择样本数最多的类作为类标记
def majorityRule(dataSet):
    mostKind = Counter([element[-1] for element in dataSet]).most_common(1)
    majorityKind = mostKind[0][0]
    return majorityKind

##计算基尼值
def calculateGini(dataSet):
    classColumnCnt = Counter([element[-1] for element in dataSet])
    gini = 0
    for symbol in classColumnCnt:
        p_k = classColumnCnt[symbol]/len(dataSet)
        gini = gini+p_k**2
    gini = 1-gini
    return gini

#子数据集构建
def makeAttributeData(dataSet,value,iColumn):
    attributeData = []
    for element in dataSet:
        if element[iColumn]==value:
            row = element[:iColumn]
            row.extend(element[iColumn+1:])
            attributeData.append(row)
    return attributeData

#计算基尼指数
def GiniIndex(dataSet,iColumn):
    index = 0.0
    attribute = set([element[iColumn] for element in dataSet])
    for value in attribute:
        attributeData = makeAttributeData(dataSet,value,iColumn)
        index = index+len(attributeData)/len(dataSet)*calculateGini(attributeData)
    return index

#选择最优属性                
def selectOptimalAttribute(dataSet,labels):
    bestGini = []
    for iColumn in range(0,len(labels)):#不计最后的类别列
        index = GiniIndex(dataSet,iColumn)
        bestGini.append(index)
    sequence = bestGini.index(min(bestGini))
    return sequence
    
#建立决策树
def createTree(dataSet,labels):
    classification = targetClass(dataSet) #获取类别种类(集合去重)
    if len(classification) == 1:
        return list(classification)[0]
    if len(labels) == 1:
        return majorityRule(dataSet)#返回样本种类较多的类别
    sequence = selectOptimalAttribute(dataSet,labels)
    optimalAttribute = labels[sequence]
    del(labels[sequence])
    myTree = {optimalAttribute:{}}
    attribute = set([element[sequence] for element in dataSet])
    for value in attribute:
        subLabels = labels[:]
        myTree[optimalAttribute][value] =  \
                createTree(makeAttributeData(dataSet,value,sequence),subLabels)
    return myTree

#定义主函数
def main():
    filePath = 'watermelonData.xls'
    data = getData(filePath)
    dataSet = dataDeal(data)
    labels = getLabels(data)
    myTree = createTree(dataSet,labels)
    return myTree

#读取数据文件并转换为列表(含有汉字的,使用CSV格式读取容易出错)
if __name__ == '__main__':
    myTree = main()
    print (myTree)

 结果竟然是一样的,深度怀疑做错了。

posted @ 2017-12-01 18:49  君以沫  阅读(1187)  评论(0编辑  收藏  举报