Implemented in Python | CART 生成树 (CART Decision-Tree Generation)
Introduction:
分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。
Algorithm:
step 1. 分别计算所有特征中各个分类的基尼系数 step 2. 选择有最小基尼系数的特征作为最优切分点,因$Gini(D,A_i=j)$最小,所以$A_i=j$作为最优切割点,$A_i$作为根节点
step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切割点,直至所有特征用尽或者是所有值都一一归类,最后所生成的决策树与ID3算法所生成的完全一致
Formula:
Code:
"""
Created on Thu Jan 30 15:36:39 2014

@filename: test.py

Driver script: train a CART decision tree from the sample data file and
print the resulting tree structure.
"""

import cart

c = cart.Cart()
c.trainDecisionTree('decision_tree_text.txt')
# trainresult holds the nested-dict tree once training has run.
print(c.trainresult)
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014

@filename: cart.py

CART (classification and regression tree) generation: at each step pick
the (feature, value) pair with the smallest Gini index as the split
point, until the split is pure or only one feature remains.
"""
import numpy as np

FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10


class Cart(object):
    """Decision-tree trainer using the Gini index (CART)."""

    def __init__(self):
        # Replaced by the trained tree after trainDecisionTree().
        self.trainresult = 'WARNING : please trainDecisionTree first!'

    def trainDecisionTree(self, filename):
        """Load the tab-separated sample file and grow the decision tree."""
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)

    def __loadDataSet(self, filename):
        """Read text samples; encode every token as a float code.

        Builds self.__dataset (raw strings) and self.__datamat (numeric
        codes) using one shared token->code dict, self.__textdic.
        """
        # 'with' closes the file handle (the original leaked it).
        with open(filename) as fread:
            self.__dataset = np.array([row.strip().split('\t')
                                       for row in fread.readlines()])
        self.__textdic = {}
        for col in self.__dataset.T:
            code = 0.0
            for cell in col:
                if cell not in self.__textdic:
                    self.__textdic[cell] = code
                    code += 1
        self.__datamat = np.array([[self.__textdic[cell] for cell in row]
                                   for row in self.__dataset])

    def __getSampleCount(self, setd, col=-1, s=None):
        """Histogram of values as {value: float count}.

        With s None: counts of column ``col`` of ``setd``.  Otherwise:
        counts of the class (last) column restricted to rows where
        ``setd[:, col] == s``.
        """
        if s is not None:
            # __getSampleMat returns (matching, rest); take the matching
            # part.  (The original indexed the tuple itself with [:, -1],
            # a latent TypeError on this code path.)
            newset = self.__getSampleMat(setd, col, s)[0][:, -1]
        else:
            newset = setd[:, col]
        dic = {}
        for cell in newset:
            dic[cell] = dic.get(cell, 0.0) + 1
        return dic

    def __getSampleMat(self, setd, col, s):
        """Split rows of ``setd`` into (rows where col == s, the rest)."""
        hit, miss = [], []
        for row in setd:
            (hit if row[col] == s else miss).append(row)
        return np.array(hit), np.array(miss)

    def __getGiniD(self, setd):
        """Gini index of the class (last) column: sum of p*(1-p)."""
        sample_count = self.__getSampleCount(setd)
        total = float(len(setd))
        gini = 0.0
        for count in sample_count.values():
            p = count / total
            gini += p * (1 - p)
        return gini

    def __getGiniDA(self, setd, a):
        """Conditional Gini index of ``setd`` for each value of feature ``a``.

        Returns ((best_value, best_gini), {value: gini}) where best_value
        has the smallest conditional Gini index.
        """
        sample_count = self.__getSampleCount(setd, a)
        dic = {}
        for value, count in sample_count.items():
            part_a, part_b = self.__getSampleMat(setd, a, value)
            w = count / len(setd)
            dic[value] = w * self.__getGiniD(part_a) + \
                (1 - w) * self.__getGiniD(part_b)
        # Select the minimal Gini index.  The original used min(dic.items()),
        # which compares by split value first, not by Gini — contradicting
        # the algorithm's "choose the smallest Gini index" step.
        return min(dic.items(), key=lambda kv: kv[1]), dic

    def __optimalNode(self, setd):
        """Return (best feature column, best split value, its Gini index)."""
        ginicol = 0
        best = (None, 1)  # Gini index never exceeds 1
        for coln in range(setd.shape[1] - 1):  # last column is the class
            gini, _ = self.__getGiniDA(setd, coln)
            if gini[1] < best[1]:
                best = gini
                ginicol = coln
        return ginicol, best[0], best[1]

    def __optimalNodeText(self, col, value):
        """Map a numeric code back to the original text of column ``col``."""
        for row, cell in enumerate(self.__dataset.T[col]):
            if self.__datamat[row, col] == value:
                return cell
        return None

    def __optimalTree(self, setd):
        """Greedily grow the tree, one optimal (feature, value) node a time."""
        arr = setd
        features = np.arange(arr.shape[1])  # original column indices
        lst = []
        defaultc = None
        for _ in range(MAXDEPTH - 1):
            ginicol, value, gini = self.__optimalNode(arr)
            parts = self.__getSampleMat(arr, ginicol, value)
            args = [np.unique(part[:, -1]) for part in parts]
            realvalues = [np.unique(part[:, ginicol])[0] for part in parts]
            realcol = features[ginicol]
            features = np.delete(features, ginicol)
            if gini == 0 or len(arr.T) == 2:
                # Pure split (or one feature left): record the branch that
                # differs from the running default class, then plot and stop.
                if args[0] == defaultc:
                    value = realvalues[0]
                else:
                    value = realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return
            # Keep recursing into whichever partition is NOT yet pure.
            if len(args[0]) == 1:
                defaultc = args[0]
                self.__buildList(lst, realcol, realvalues[0], gini)
                keep = parts[1]
            else:
                defaultc = args[1]
                self.__buildList(lst, realcol, realvalues[1], gini)
                keep = parts[0]
            # Drop the consumed feature column from the working matrix.
            arr = np.concatenate((keep[:, :ginicol], keep[:, ginicol + 1:]),
                                 axis=1)

    def __plotTree(self, lst):
        """Convert the node list to a nested dict and draw it."""
        import plottree  # deferred: only plotting needs matplotlib

        dic = {}
        for item in lst:
            if dic == {}:
                dic[item[0]] = {item[1]: 'c1', 'ELSE': 'c2'}
            else:
                dic = {item[0]: {item[1]: 'c1', 'ELSE': dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)

    def __buildList(self, lst, col, value, gini):
        """Prepend [col, 'text:value'] to ``lst`` and log the chosen node."""
        text = self.__optimalNodeText(col, value)
        print('feature col:', col,
              ' feature val:', text,
              ' Gini:', gini, '\n')
        lst.insert(0, [col, str(text) + ':' + str(value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014

@filename: plottree.py

Draw a decision tree (a nested dict) with matplotlib annotations.
"""

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="1.0")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node box with an arrow from parentPt."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
                            xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center",
                            ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict tree."""
    numLeafs = 0
    # py3-safe first key (the original's myTree.keys()[0] fails on py3).
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        # isinstance replaces the original's fragile
        # type(x).__name__ is 'dict' identity test.
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Depth of the longest root-to-leaf path, counting decision nodes."""
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def retrieveTree(dic=None):
    """Return ``dic``, or the textbook example tree when omitted.

    A None sentinel replaces the original mutable default argument, so
    callers no longer share (and can no longer mutate) one default dict.
    """
    if dic is None:
        dic = {'have house': {'yes': 'c1',
                              'no': {'have job': {'yes': 'c1',
                                                  'no': 'c2'}}}}
    return dic


def plotMidText(centrPt, parentPt, txtString):
    """Write the branch label halfway along the parent->child edge."""
    xMid = (parentPt[0] - centrPt[0]) / 2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) / 2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively place and draw the subtree rooted at ``myTree``.

    Uses function attributes (totalW/totalD/xOff/yOff) set by createPlot
    as shared layout state.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = next(iter(myTree))
    centrPt = [plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
               plotTree.yOff]
    plotMidText(centrPt, parentPt, nodeTxt)
    plotNode(firstStr, centrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], centrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     centrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Open a figure and draw ``inTree`` filling the axes."""
    # Deferred import: the traversal helpers above stay usable on a
    # headless machine without matplotlib installed.
    import matplotlib.pyplot as plt

    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree()
    createPlot(myTree)
输入数据
输出结果
feature col: 2 feature val: 是 Gini: 0.266666666667
feature col: 1 feature val: 是 Gini: 0.0
Reference:
Harrington P. Machine Learning in Action
李航. 统计学习方法