python ID3决策树实现
环境:ubuntu 16.04 python 3.6
数据来源:UCI Wine 数据集(wine.data,比较经典的葡萄酒分类数据)
决策树要点:
1、 如何确定分裂点(CART、ID3、C4.5 算法各有对应的划分准则:基尼指数、信息增益、信息增益率)
2、 如何处理连续值数据,以及如何处理缺失数据
3、 剪枝处理
尝试实现这个算法,一是为了熟悉 Python,二是为了更好地理解算法的整体流程以及一些要点的处理。
from collections import Counter
from math import log
import operator
import os
import pickle

import numpy as np


def debug(value_name, value):
    """Print ``value`` under a heading that names it (ad-hoc debugging helper)."""
    print("debuging for %s" % value_name)
    print(value)


def loadDateset(path='./wine.data'):
    """Load a comma-separated data file whose FIRST column is the class label.

    The class column is moved to the end of every row, because the tree code
    (``informationEntropy``, ``chooseFeature``, ``createTree``) reads the class
    label from the last element of a row. Values are kept as the strings that
    appear in the file.

    Args:
        path: path of the data file (defaults to ``./wine.data``).

    Returns:
        ``(dataSet, featLabels)``: ``dataSet`` is a 2-D NumPy string array whose
        rows are ``[feature_0, ..., feature_{k-1}, class]``; ``featLabels`` is
        ``[0, ..., k-1]``, i.e. each feature's column index in a test vector
        passed to ``classify``. An empty file yields an empty array and ``[]``.
    """
    with open(path) as f:
        # Skip blank lines (e.g. a trailing newline): a [''] row would make
        # the array ragged.
        rows = [line.strip().split(',') for line in f if line.strip()]
    if not rows:
        return np.empty((0, 0), dtype=str), []
    wine = np.array(rows)
    wine_label = wine[..., :1]
    wine_data = wine[..., 1:]
    # Put the class column last; without it the last feature would be
    # mistaken for the class label.
    dataSet = np.concatenate((wine_data, wine_label), axis=1)
    # One label per feature column (not per row).
    featLabels = list(range(wine_data.shape[1]))
    return dataSet, featLabels


def informationEntropy(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    The class label of a row is its last element. An empty data set has
    entropy 0.0.
    """
    m = len(dataSet)
    if m == 0:
        return 0.0
    labelCounts = Counter(row[-1] for row in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / m
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, feature):
    """Return the rows whose column ``axis`` equals ``feature``, minus that column.

    Args:
        dataSet: iterable of rows (lists or 1-D NumPy arrays).
        axis: column to test and remove; negative indices count from the end.
        feature: value the column must equal.

    Returns:
        A new list of rows (always plain lists); ``dataSet`` is not modified.
    """
    subDataSet = []
    for featVec in dataSet:
        if featVec[axis] != feature:
            continue
        row = featVec.tolist() if isinstance(featVec, np.ndarray) else list(featVec)
        col = axis + len(row) if axis < 0 else axis
        subDataSet.append(row[:col] + row[col + 1:])
    return subDataSet


def chooseFeature(dataSet):
    """Return the index of the feature with the largest information gain.

    Every distinct value of a feature becomes its own branch (ID3 treats even
    numeric columns as categorical). The last column is the class and is never
    considered. Returns -1 when ``dataSet`` is empty or no feature yields a
    strictly positive gain.
    """
    if len(dataSet) == 0:
        return -1
    numFeature = len(dataSet[0]) - 1
    baseEntropy = informationEntropy(dataSet)
    total = float(len(dataSet))
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeature):
        newEntropy = 0.0
        for value in set(row[i] for row in dataSet):
            subDataSet = splitDataSet(dataSet, i, value)
            newEntropy += len(subDataSet) / total * informationEntropy(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in ``classList``; ties go to the first seen.

    Raises:
        IndexError: ``classList`` is empty.
    """
    return Counter(classList).most_common(1)[0][0]


def createTree(dataSet, Featlabels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: non-empty sequence of rows ``[feature..., class]``.
        Featlabels: one label per feature column of ``dataSet``; these labels
            become the node keys. The list is NOT modified.

    Returns:
        A class label (leaf) or a nested dict
        ``{feature_label: {feature_value: subtree_or_class, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All rows share one class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left to split on.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseFeature(dataSet)
    # No split reduces entropy (e.g. identical features, different classes);
    # -1 would otherwise index the class column.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = Featlabels[bestFeat]
    # Labels of the columns that remain after removing bestFeat.
    subFeatLabels = Featlabels[:bestFeat] + Featlabels[bestFeat + 1:]
    branches = {}
    for value in set(example[bestFeat] for example in dataSet):
        subdataSet = splitDataSet(dataSet, bestFeat, value)
        branches[value] = createTree(subdataSet, subFeatLabels[:])
    return {bestFeatLabel: branches}


def _matchBranch(branches, value):
    """Return the subtree of ``branches`` whose key matches ``value``.

    Keys are the raw strings from the data, so an exact ``str`` match is tried
    first; failing that, keys are compared numerically so that e.g. the test
    value 0.28 still matches the key ``'.28'``.

    Raises:
        KeyError: no branch matches (value unseen at this node during training).
    """
    key = str(value)
    if key in branches:
        return branches[key]
    try:
        target = float(value)
    except (TypeError, ValueError):
        raise KeyError(key) from None
    for branchKey, subTree in branches.items():
        try:
            if float(branchKey) == target:
                return subTree
        except (TypeError, ValueError):
            continue
    raise KeyError(key)


def classify(inputTree, featLabels, testVec):
    """Return the class predicted by ``inputTree`` for ``testVec``.

    Args:
        inputTree: tree produced by ``createTree``.
        featLabels: the full feature-label list; ``featLabels.index(label)``
            gives the position of that feature in ``testVec``.
        testVec: feature values (numbers or strings), without a class column.

    Raises:
        KeyError: ``testVec`` has a value for which the tree has no branch.
    """
    nowNode = next(iter(inputTree))
    featIndex = featLabels.index(nowNode)
    subTree = _matchBranch(inputTree[nowNode], testVec[featIndex])
    if isinstance(subTree, dict):
        return classify(subTree, featLabels, testVec)
    return subTree


if __name__ == '__main__':
    wine_data, featLabels = loadDateset()
    myTree = createTree(wine_data, featLabels.copy())
    # Feature values of one sample (13 features, no class column).
    test = [14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.06, .28, 2.29, 5.64, 1.04, 3.92, 1065]
    print(classify(myTree, featLabels, test))
静下来,你想要的东西才能看见