李航——决策树代码

# -*- coding: utf-8 -*-
"""
Created on Tue May 15 15:28:42 2018

@author: baochen
"""
from math import log
import numpy
import operator

def CalEnt(database):
    LabelNumber = len(database)
    LabelDic = {}
    
    for line in database:
        LineLabel = line[-1]
        if LineLabel not in LabelDic.keys():
            LabelDic[LineLabel] = 0
        LabelDic[LineLabel] += 1
        
        
    ShanonEnt = 0.0
    for key in LabelDic:
        prob = LabelDic[key]/LabelNumber
        ShanonEnt -= prob*log(prob,2) 
        
    return ShanonEnt

def CreatDatabase():
    database =[[1,0,0,0,'no'],[1,0,0,1,'no'],[1,1,0,1,'yes'],[1,1,1,0,'yes'],[1,0,0,0,'no'],
               [2,0,0,0,'no'],[2,0,0,1,'no'],[2,1,1,1,'yes'],[2,0,1,2,'yes'],[2,0,1,2,'no'],
               [3,0,1,2,'yes'],[3,0,1,1,'no'],[3,1,0,1,'yes'],[3,1,0,2,'yes'],[3,0,0,0,'no']]
    return database
#axis表示维度,value表示区别
def SplitDatabase(database,axis,value):
    retDatabase = []
    for line in database:
        if line[axis] == value:
            retDatabase.append(line)            
    return retDatabase

def ChooseBest(database):
    baseShanon = CalEnt(database)
    bestInformationGain = 0.0
    ConShanon = 0.0
    FeaNum = len(database[0][:])-1
    #print(FeaNum)
    BestChoose = -1
    LabelNum = len(database[:][0])

    
    for i in range(FeaNum):
        
        if i not in t:
            
            FeatList = [temp[i] for temp in database]
            PureFeatList = set(FeatList)
       
            for value in PureFeatList:
                subdatabase = SplitDatabase(database,i,value)
                prob = len(subdatabase)/float(len(database))
                ConShanon -= prob*CalEnt(subdatabase) 
                InformationGain = baseShanon - ConShanon
            #print(i)
                if InformationGain > bestInformationGain:
                    bestInformationGain = InformationGain
                    BestChoose = i
               # print(BestChoose)
    return BestChoose
            
def majorityEnt(LabelList):
    LabelCount = {}
    for vote in LabelList:
        if vote not in LabelList.keys():
            LabelCount[vote] = 0
            
        LabelList[vote] += 1
        #sorted(iterable,cmp,key,reverse = true) 
        #第一个是迭代器,第二个是判断函数,第三个是分类数据,第四个是正序反序
        #因为第一个需要迭代器,所以我们生成迭代器,第二个不管,第三个operator.itemgetter表示按第一个域进行排序
    sortedLabelCount = sorted(LabelCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    #字典也可以当做数组的处理方法来搞
    return sortedLabelCount[0][0]
            
            

def CreatTree(database,label):
    LabelList = [x[-1] for x in database]
    #判断某一行是否全为某个数的方法,就是判断第一个数的个数是否等于该行的全部数目
    if LabelList.count(LabelList[0]) == len(LabelList):
        return LabelList[0]
    
    if len(database[0]) == 1:
        return majorityEnt(LabelList)
    
    
    Feature = ChooseBest(database)
    t.append(Feature)
    #print(t)
    
 #   print(Feature)
    BestLabel = label[Feature]
    print(BestLabel)

    
    
    MyTree = {BestLabel:{}}
    
    #del(label[Feature])
        #print(database)
    
    
    featValues = [x[Feature] for x in database]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = label
        MyTree[BestLabel][value] = CreatTree(SplitDatabase(database,Feature,value),subLabels)
        
    return MyTree
    
    

    
        


def test():
    global t 
    t = []
    mydatabase = CreatDatabase()
    '''
    s = CalEnt(mydatabase)
    t=SplitDatabase(mydatabase,0,1)
    print(s)
    print(t)
    '''
    label= ['age','work','house','money']
    print(CreatTree(mydatabase,label))
    
test()
money
age
work
house
money

{'money': {0: {'age': {1: {'work': {0: 'no', 1: 'yes'}}, 2: 'no', 3: 'no'}}, 1: {'house': {0: {'money': {'no': 'no', 'yes': 'yes'}}, 1: {'money': {'no': 'no', 'yes': 'yes'}}}}, 2: {'money': {'no': 'no', 'yes': 'yes'}}}}

感觉最后面的地方有点乱了

有空优化一下。

posted on 2018-05-16 15:36  maxwell_tesla  阅读(188)  评论(0编辑  收藏  举报