IMPLEMENTED IN PYTHON +1 | CART Tree Generation

Introduction:

The classification and regression tree (CART) model was proposed by Breiman et al. in 1984. Like the other decision-tree methods, CART consists of feature selection, tree generation, and pruning, and it can be used for both classification and regression. This post briefly covers the tree-generation part; pruning will be discussed in a follow-up post.

Algorithm:

step 1. For each feature, compute the Gini index of the split induced by each of its values.

step 2. Choose the feature value with the smallest Gini index as the optimal split: since $Gini(D,A_i=j)$ is minimal, $A_i=j$ becomes the optimal split point and $A_i$ becomes the root node.

step 3. Repeat steps 1 and 2 on the remaining features to obtain the next optimal feature and split point, until all features are used up or every sample is classified. The tree finally generated is identical to the one produced by the ID3 algorithm (a runnable sketch of this selection loop follows these steps).
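Before the full implementation, a minimal self-contained sketch of that selection loop (steps 1 and 2) may help. This is an illustration only, not the Cart class below: the helper names gini and best_split are mine, and samples are assumed to be plain Python rows with the class label in the last column.

def gini(rows):
    # Gini(D) = 1 - sum_k p_k^2 over the class labels (last column).
    total = float(len(rows))
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return 1.0 - sum((n / total) ** 2 for n in counts.values())

def best_split(data):
    # Step 1: compute Gini(D, A=a) for every feature/value pair;
    # step 2: keep the pair with the smallest weighted Gini index.
    best = (None, None, 1.0)           # (feature column, value, gini)
    n_features = len(data[0]) - 1      # last column is the class label
    for col in range(n_features):
        for value in set(row[col] for row in data):
            left  = [row for row in data if row[col] == value]
            right = [row for row in data if row[col] != value]
            if not left or not right:
                continue
            w = float(len(left)) / len(data)
            g = w * gini(left) + (1.0 - w) * gini(right)
            if g < best[2]:
                best = (col, value, g)
    return best

rows = [['yes', 'c1'], ['yes', 'c1'], ['no', 'c2']]
print best_split(rows)                 # -> (0, 'yes', 0.0)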

Formula:
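For a sample set $D$ with $K$ classes, where $p_k$ is the proportion of samples belonging to class $k$, the Gini index is

$$Gini(D)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^2$$

Splitting $D$ by the test $A=a$ into $D_1=\{(x,y)\in D\mid A(x)=a\}$ and its complement $D_2$ gives

$$Gini(D,A=a)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$$

These are exactly the quantities computed by __getGiniD and __getGiniDA in cart.py below.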

Code:

 1 """
 2 Created on Thu Jan 30 15:36:39 2014
 3 
 4 @filename: test.py
 5 """
 6 
 7 import cart
 8 
 9 c = cart.Cart()
10 c.trainDecisionTree('decision_tree_text.txt')
11 print c.trainresult
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 18:05:22 2014

@filename: cart.py
"""
import numpy as np
import plottree

FILENAME = 'decision_tree_text.txt'
MAXDEPTH = 10


class Cart():
    def __init__(self):
        self.trainresult = 'WARNING : please trainDecisionTree first!'


    def trainDecisionTree(self, filename):
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)


    def __loadDataSet(self, filename):
        # Read a tab-separated text file (one sample per row, class label
        # in the last column) and map every text value to a float code.
        fread = open(filename)
        self.__dataset = np.array([row.strip().split('\t')
                                   for row in fread.readlines()])
        fread.close()
        self.__textdic = {}
        for col in self.__dataset.T:
            i = 0.
            for cell in col:
                if cell not in self.__textdic:
                    self.__textdic[cell] = i
                    i += 1
        self.__datamat = np.array([[self.__textdic[cell] for cell in row]
                                   for row in self.__dataset])


    def __getSampleCount(self, setd, col = -1, s = None):
        # Count occurrences of each value in column `col` (by default the
        # label column); if a split value `s` is given, count labels only
        # in the subset where column `col` equals `s`.
        dic = {}
        if s is not None:
            newset = self.__getSampleMat(setd, col, s)[0][:,-1]
        else:
            newset = setd[:,col]
        for cell in newset:
            if cell not in dic:
                dic[cell] =  1.
            else:
                dic[cell] += 1
        return dic


    def __getSampleMat(self, setd, col, s):
        # Binary split: rows where setd[:,col] == s, and the rest.
        lista = []; listb = []
        for row in setd:
            if row[col] == s:
                lista.append(row)
            else:
                listb.append(row)
        return np.array(lista), np.array(listb)


    def __getGiniD(self, setd):
        # Gini(D) = sum_k p_k * (1 - p_k)
        sample_count = self.__getSampleCount(setd)
        gini = 0
        for item in sample_count.items():
            gini += item[1]/len(setd) * (1 - item[1]/len(setd))
        return gini


    def __getGiniDA(self, setd, a):
        # Gini(D, A=a) for every value of feature column `a`; return the
        # (value, gini) pair with the smallest gini and the full map.
        sample_count = self.__getSampleCount(setd, a)
        dic = {}
        for item in sample_count.items():
            setd_part_a, setd_part_b = self.__getSampleMat(setd, a, item[0])
            gini = item[1]/len(setd) * self.__getGiniD(setd_part_a) + \
                   (1 - item[1]/len(setd)) * self.__getGiniD(setd_part_b)
            dic[item[0]] = gini
        # min by gini value, not by dictionary key
        return min(dic.items(), key = lambda kv: kv[1]), dic


    def __optimalNode(self, setd):
        # Scan every feature column and keep the (column, value) split
        # with the globally smallest Gini index.
        ginicol = 0
        mingini = (None, 1.)
        for coln in range(len(setd.T) - 1):
            gini, dic = self.__getGiniDA(setd, coln)
            if gini[1] < mingini[1]:
                mingini = gini
                ginicol = coln
        return ginicol, mingini[0], mingini[1]


    def __optimalNodeText(self, col, value):
        # Map a float code back to the original text of that column.
        row = 0
        tex = None
        for cell in self.__dataset.T[col]:
            if self.__datamat[row,col] == value:
                tex = cell
                break
            row += 1
        return tex


    def __optimalTree(self, setd):
        # Grow the tree greedily: take the optimal split, record it, peel
        # off the pure branch, and keep splitting the remaining rows.
        arr      = setd
        count    = MAXDEPTH-1
        features = np.array(range(len(arr.T)))
        lst      = []
        defaultc = None
        while count > 0:
            count               -= 1
            ginicol, value, gini = self.__optimalNode(arr)
            parts                = self.__getSampleMat(arr, ginicol, value)
            args                 = [np.unique(part[:,-1]) for part in parts]
            realvalues           = [np.unique(part[:,ginicol])[0] for part in parts]
            realcol              = features[ginicol]
            features             = np.delete(features, ginicol)
            if gini == 0 or len(arr.T) == 2:
                # Stop: the split is pure, or only one feature is left.
                if args[0] == defaultc:
                    value  = realvalues[0]
                else:
                    value  = realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return
            if len(args[0]) == 1:
                # The "equal" branch is pure; recurse on the other branch.
                defaultc = args[0]
                self.__buildList(lst, realcol, realvalues[0], gini)
                arr = np.concatenate((parts[1][:,:ginicol], \
                      parts[1][:,ginicol+1:]), axis=1)
            else:
                defaultc = args[1]
                self.__buildList(lst, realcol, realvalues[1], gini)
                arr = np.concatenate((parts[0][:,:ginicol], \
                      parts[0][:,ginicol+1:]), axis=1)


    def __plotTree(self, lst):
        # Fold the recorded splits into a nested dict and plot it.
        dic = {}
        for item in lst:
            if dic == {}:
                dic[item[0]] = {item[1]:'c1','ELSE':'c2'}
            else:
                dic = {item[0]:{item[1]:'c1','ELSE':dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)


    def __buildList(self, lst, col, value, gini):
        print 'feature col:', col, \
              ' feature val:', self.__optimalNodeText(col, value), \
              ' Gini:', gini, '\n'
        lst.insert(0, [col, str(self.__optimalNodeText(col, \
                   value)) + ':' + str(value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
    cart.trainDecisionTree(FILENAME)
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 29 11:45:18 2014

@filename: plottree.py
"""

import matplotlib.pyplot as plt

# Box and arrow styles for the matplotlib annotations.
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "1.0")
arrow_args = dict(arrowstyle = "<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, \
    xycoords = 'axes fraction', xytext = centerPt, \
    textcoords = 'axes fraction', va = "center", \
    ha = "center", bbox = nodeType, arrowprops = arrow_args)

def getNumLeafs(myTree):
    # Recursively count the leaves to size the plot horizontally.
    numLeafs   = 0
    firstStr   = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    # Recursively find the depth to size the plot vertically.
    maxDepth   = 0
    firstStr   = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(dic = {'have house': {'yes': 'c1', 'no': {'have job': \
                   {'yes': 'c1', 'no': 'c2'}}}}):
    return dic

def plotMidText(centrPt, parentPt, txtString):
    # Label the edge between a parent node and a child node.
    xMid = (parentPt[0] - centrPt[0]) /2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) /2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    firstStr = myTree.keys()[0]
    centrPt  = [plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
                plotTree.yOff]
    plotMidText(centrPt, parentPt, nodeTxt)
    plotNode(firstStr, centrPt, parentPt, decisionNode)
    secondDict    = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], centrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), \
                     centrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), centrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig              = plt.figure(1, facecolor = 'white')
    fig.clf()
    axprops          = dict(xticks = [], yticks = [])
    createPlot.ax1   = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW  = float(getNumLeafs(inTree))
    plotTree.totalD  = float(getTreeDepth(inTree))
    plotTree.xOff    = -0.5/plotTree.totalW
    plotTree.yOff    = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

if __name__ == '__main__':
    myTree = retrieveTree()
    createPlot(myTree)

Input data: decision_tree_text.txt, a tab-separated text file with one sample per row, feature columns first and the class label in the last column.

Output:

feature col: 2  feature val: 是  Gini: 0.266666666667 

feature col: 1  feature val: 是  Gini: 0.0 
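Reading the log together with __plotTree: the root split is on feature column 2 with value 是 (yes), and the remaining samples are then split on column 1 with a Gini index of 0, so the nested dict handed to plottree.createPlot has the shape below (the numeric code after each colon depends on the encoding order in __loadDataSet and is elided here):

{2: {'是:…': 'c1',
     'ELSE': {1: {'是:…': 'c1',
                  'ELSE': 'c2'}}}}

If the input is the loan-applicant example from 统计学习方法, column 2 is "has house" and column 1 is "has job", which matches the ID3 result mentioned in step 3.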

Reference:

Harrington, P. Machine Learning in Action. Manning, 2012.

李航. 统计学习方法. 清华大学出版社, 2012.
