李航——决策树代码
# -*- coding: utf-8 -*-
"""
Created on Tue May 15 15:28:42 2018

@author: baochen

ID3 decision tree built on Shannon entropy / information gain.

Every row of a "database" is a list whose last element is the class label
and whose preceding elements are discrete feature values.
"""
from collections import Counter
from math import log
import operator

import numpy  # not used by this script; kept from the original import list


def CalEnt(database):
    """Return the Shannon entropy (base 2) of the class labels in *database*.

    The class label is the last element of every row. An empty database has
    entropy 0.0.
    """
    total = len(database)
    label_counts = Counter(row[-1] for row in database)
    entropy = 0.0
    for count in label_counts.values():
        prob = count / float(total)
        entropy -= prob * log(prob, 2)
    return entropy


def CreatDatabase():
    """Return the 15-row sample data set: age, work, house, money, label."""
    # NOTE(review): rows 10 and 12 are labelled 'no' here; this looks like it
    # is meant to be the loan table from Li Hang's book, where they would be
    # 'yes' -- confirm before changing. Rows 9 and 10 have identical features
    # but different labels, so the tree builder needs a majority-vote leaf.
    database = [[1, 0, 0, 0, 'no'], [1, 0, 0, 1, 'no'], [1, 1, 0, 1, 'yes'],
                [1, 1, 1, 0, 'yes'], [1, 0, 0, 0, 'no'],
                [2, 0, 0, 0, 'no'], [2, 0, 0, 1, 'no'], [2, 1, 1, 1, 'yes'],
                [2, 0, 1, 2, 'yes'], [2, 0, 1, 2, 'no'],
                [3, 0, 1, 2, 'yes'], [3, 0, 1, 1, 'no'], [3, 1, 0, 1, 'yes'],
                [3, 1, 0, 2, 'yes'], [3, 0, 0, 0, 'no']]
    return database


def SplitDatabase(database, axis, value):
    """Return the rows of *database* whose column *axis* equals *value*.

    Rows are returned whole (the column is not removed), so column indices
    stay valid in the subset.
    """
    return [row for row in database if row[axis] == value]


def ChooseBest(database, used=None):
    """Return the index of the feature with the largest information gain.

    Args:
        database: rows of feature values followed by a class label.
        used: optional iterable of feature indices that must not be chosen
            (features already consumed on the path from the root).

    Returns:
        The best feature index, or -1 when the database is empty or no
        remaining feature has a strictly positive information gain.
    """
    if not database:
        return -1
    excluded = set(used) if used is not None else set()
    base_entropy = CalEnt(database)
    total = float(len(database))
    feature_count = len(database[0]) - 1  # last column is the label
    best_gain = 0.0
    best_feature = -1
    for index in range(feature_count):
        if index in excluded:
            continue
        # Conditional entropy H(D|A) must start from zero for every feature.
        conditional_entropy = 0.0
        for value in set(row[index] for row in database):
            subset = SplitDatabase(database, index, value)
            conditional_entropy += len(subset) / total * CalEnt(subset)
        gain = base_entropy - conditional_entropy
        # Strict '>' keeps the lowest index on ties and rejects zero gain.
        if gain > best_gain:
            best_gain = gain
            best_feature = index
    return best_feature


def majorityEnt(LabelList):
    """Return the most frequent label in *LabelList*.

    On a tie the label encountered first wins.

    Raises:
        IndexError: if *LabelList* is empty.
    """
    return Counter(LabelList).most_common(1)[0][0]


def CreatTree(database, label, used=None):
    """Recursively build an ID3 decision tree as nested dictionaries.

    Args:
        database: non-empty list of rows (feature values followed by a label).
        label: feature names, indexed like the feature columns. Not modified.
        used: optional iterable of feature indices already consumed on the
            path from the root; callers normally leave it as None.

    Returns:
        A class label for a leaf, otherwise
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    class_labels = [row[-1] for row in database]
    # All rows share one class: pure leaf.
    if class_labels.count(class_labels[0]) == len(class_labels):
        return class_labels[0]

    consumed = set(used) if used is not None else set()
    feature = ChooseBest(database, consumed)
    # No feature left, or none that separates the classes: majority vote.
    # Without this guard, -1 would index the label column itself.
    if feature == -1:
        return majorityEnt(class_labels)

    feature_name = label[feature]
    # A fresh set per node, so sibling branches do not see each other's
    # choices.
    branch_used = consumed | set([feature])
    tree = {feature_name: {}}
    for value in set(row[feature] for row in database):
        subset = SplitDatabase(database, feature, value)
        tree[feature_name][value] = CreatTree(subset, label, branch_used)
    return tree


def test():
    """Build the tree for the sample data set and print it."""
    mydatabase = CreatDatabase()
    label = ['age', 'work', 'house', 'money']
    print(CreatTree(mydatabase, label))


if __name__ == "__main__":
    test()
money age work house money {'money': {0: {'age': {1: {'work': {0: 'no', 1: 'yes'}}, 2: 'no', 3: 'no'}}, 1: {'house': {0: {'money': {'no': 'no', 'yes': 'yes'}}, 1: {'money': {'no': 'no', 'yes': 'yes'}}}}, 2: {'money': {'no': 'no', 'yes': 'yes'}}}}
感觉输出结果最后面的地方有点乱了:树里出现了 'money': {'no': 'no', 'yes': 'yes'} 这样按类别标签继续划分的多余分支。
有空优化一下。
posted on 2018-05-16 15:36 maxwell_tesla 阅读(188) 评论(0) 编辑 收藏 举报