《机器学习实战》 | 第3章 决策树(含Matplotlib模块介绍)
系列文章:《机器学习实战》学习笔记
本篇文章使用到的完整代码:Here
决策树
- 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
- 缺点:可能会产生过度匹配问题。
- 适用数据类型:离散型和连续型
我们经常使用决策树处理分类问题,它的过程类似二十个问题的游戏:参与游戏的一方在脑海里想某个事物,其他参与者向他提出问题,只允许提20个问题,问题的答案也只能用对或错回答。问问题的人通过推断分解,逐步缩小带猜测事物的范围。如图1所示的流程图就是一个决策树,长方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),表示已经得出结论,可以终止运行。从判断模块引出的左右箭头称作分支(branch),它可以到达另一个判断模块或终止模块。
图1构造了一个假象的邮件分类系统,它首先检测发送邮件域名地址。如果地址为myEmployer.com,则将其放在分类"无聊时需要阅读的邮件"中。如果邮件不是来自这个域名,则检查内容是否包括单词曲棍球,如果包含则将邮件归类到"需要及时处理的朋友邮件",否则将邮件归类到"无须阅读的垃圾邮件"。
第2章介绍的k-近邻算法可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内在含义,决策树的主要优势就在于数据形式非常容易理解。
本章构造的决策树算法能够读取数据集合,构建类似图1的决策树。决策树可以在数据集合中提取出一系列规则,规则创建的过程就是机器学习的过程。现在我们已经大致了解决策树可以完成哪些任务,接下来我们将学习如何从一堆原始数据中构造决策树。首先我们讨论构造决策树的方法,以及如何编写构造树的Python代码;接着提出一些度量算法成功率的方法;最后使用递归建立分类器。
一、决策树的构造
在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。我们假设已经根据一定的方法选取了待划分的特征,则原始数据集将根据这个特征被划分为几个数据子集。这数据子集会分布在决策点(关键特征)的所有分支上。如果某个分支下的数据属于同一类型,则无需进一步对数据集进行分割。如果数据子集内的数据不属于同一类型,则需要递归地重复划分数据子集的过程,直到每个数据子集内的数据类型相同。
创建分支的过程用伪代码表示如下:
检测数据集中的每个子项是否属于同一类型:
如果是,则返回类型标签
否则:
寻找划分数据集的最好特征
划分数据集
创建分支节点
对划分的每个数据子集:
递归调用本算法并添加返回结果到分支节点中
返回分支节点
注:伪代码是一个递归函数。
决策树的一般流程:
- 收集数据:可以使用任何方法。
- 准备数据:树构造算法只适用于标称数据,因此数值型数据必须离散化。
- 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
- 训练算法:构造树的数据结构。
- 测试算法:使用经验树计算错误率。
- 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
一些决策树算法使用二分法划分数据,本书并不采用这种方法。如果依据某个属性划分数据将会产生4个可能的值,我们将把数据划分成四块,并创建四个不同的分支。
本书将使用ID3算法划分数据集,该算法处理如何划分数据集,何时停止划分数据集(进一步的信息可以参见http://en.wikipedia.org/wiki/ID3_algorithm)。每次划分数据集我们只选取一个特征属性,那么应该选择哪个特征作为划分的参考属性呢?
表1的数据包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚噗。我们可以将这些动物分成两类:鱼类和非鱼类。
表1 海洋生物数据
不浮出水面是否可以生存 | 是否有脚蹼 | 属于鱼类 | |
---|---|---|---|
1 | 是 | 是 | 是 |
2 | 是 | 是 | 是 |
3 | 是 | 否 | 否 |
4 | 否 | 是 | 否 |
5 | 否 | 是 | 否 |
1.1 信息增益
划分数据集的大原则是:将无序的数据变得更加有序。我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据之前或之后使用信息论量化度量信息的内容。
在划分数据集之前之后信息发生的变化成为信息增益,我们可以计算每个特征划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
对于某件事情
不确定性越大,熵越大,确定该事所需的信息量也越大;
不确定性越小,熵越小,确定该事所需的信息量也越小。(个人理解:将乱序数据转化为有序数据前后变化为信息增益,数据的信息的混乱程度叫熵)。
集合信息的度量方式成为香农熵或者简称为熵。
熵定义为信息的期望值。我们先确定信息的定义:
如果待分类的事务可能划分在多个分类之中,则符号 \(x_i\) 定义为:
其中 \(p(x_i)\) 是选择该分类的概率。
为了计算熵,我们需要计算所有类型所有可能值包含的信息的期望值,通过下面的公式得到:
其中 \(n\) 是分类的数目。
下面给出计算信息熵的 Python
函数,创建名为 trees.py
文件,添加如下代码:
from math import log
# H(x) = -\sum_{i = 1}^nP(X_i)log_2P(X_i)
def calsShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
# 为所有可能的字创建字典
for dataVec in dataSet:
label = dataVec[-1]
if label not in labelCounts.keys(): # 为所有可能分类创建字典
labelCounts[label] = 0
labelCounts[label] += 1
shannonEnt = 0.0
for key in labelCounts.keys():
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2) # 以2为底求对数
return shannonEnt
代码说明:
- 首先,计算数据集中实例的总数。我们可以在需要时再计算这个值,但是由于代码中多次用到这个值,为了提高代码效率,我们显式地声明一个变量保存实例总数。
- 然后,创建一个数据字典,它的键值是最后一列的数值。如果当前键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的粗疏。
- 最后,使用所有类标签的发生频率计算类别出现的概率。我们将用这个概率计算香农熵,统计所有类标签发生的次数。
在 trees.py
文件中,我们利用 createDateSet()
函数得到一些样例数据:
def creatDataSet():
dataSet = [
[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no'],
]
labels = ['no surfacng', 'flippers']
return dataSet, labels
熵越高,则混合的数据也越多。得到熵之后,我们就可以按照获得最大信息增益的方法划分数据集。
另一个度量集合无序程度的方法是基尼不纯度(Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。
1.2 划分数据集
我们将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集市最好的划分方法。
添加划分数据集的代码:
def splitDataSet(dataSet, axis, value):
retDataSet = [] # 创建新的list对象
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reducedFeatVec) # 抽取
return retDataSet
该函数使用了三个输入参数:带划分的数据集、划分数据集的特征、需要返回的特征的值。函数先选取数据集中第axis个特征值为value的数据,从这部分数据中去除第axis个特征,并返回。
测试这个函数,效果如下:
>>> import trees
>>> myDat, labels = trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.splitDataSet(myDat,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> trees.splitDataSet(myDat,0,0)
[[1, 'no'], [1, 'no']]
接下来我们将遍历整个数据集,循环计算香农熵和 splitDataSet()
函数,找到最好的特征划分方式。
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calsShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calsShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
函数选取了第一个特征用于划分。
1.3 递归构建决策树
构造决策树所需的子功能模块已经介绍完毕,构建决策树的算法流程如下:
- 得到原始数据集,
- 基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。
- 第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。我们可以采用递归的原则处理数据集。
- 递归结束的条件是,程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。
参加图2所示:
在 trees.py
中添加下面的程序代码:
import operator
def majority(classList):
classCount = {}
for vote in classList:
if vote not in classCount.key(): classCount[vote] = 0
classCount[vote] += 1
sortedclassCount = sorted(classCount.iteritems(),
key=operator.itemgetter(1),
reverse=True)
return sortedclassCount[0][0]
# 创建树
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
# 类型完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有特征时返回出现次数最多的
if len(dataSet[0]) == 1:
return majority(classList)
bestFeat = chooseBestFeatureToSplit(dataSet=dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
del (labels[bestFeat])
# 得到列表包含的所有属性值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
sublabels = labels[:] # 复制labels列表
myTree[bestFeatLabel][value] = createTree(
splitDataSet(dataSet, bestFeat, value), sublabels) # 递归构造子树
return myTree
majorityCnt
函数统计 classList
列表中每个类型标签出现频率,返回出现次数最多的分类名称。
createTree
函数使用两个输入参数:数据集 dataSet
和标签列表 labels
标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但是为了给出数据明确的含义,我们将它作为一个输入参数提供。
上述代码首先创建了名为 classList
的列表变量,其中包含了数据集的所有类标签。列表变量classList
包含了数据集的所有类标签。递归函数的第一个停止条件是所有类标签完全相同,则直接返回该类标签。递归函数的第二个停止条件是使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。这里使用 majorityCnt
函数挑选出现次数最多的类别作为返回值。
下一步程序开始创建树,这里直接使用 Python
的字典类型存储树的信息。字典变量 myTree
存储树的所有信息。当前数据集选取的最好特征存储在变量 bestFeat
中,得到列表中包含的所有属性值。
最后代码遍历当前选择特征包含的所有属性值,在每个数据集划分上递归待用函数 createTree()
,得到的返回值将被插入到字典变量 myTree
中,因此函数终止执行时,字典中将会嵌套很多代表叶子节点信息的字典数据。
注意其中的 subLabels = labels[:]
复制了类标签,因为在递归调用 createTree
函数中会改变标签列表的值。
测试这些函数:
>>> import trees
>>> myDat, labels = trees.createDataSet()
>>> myTree = trees.createTree(myDat,labels)
>>> myTree
{'no surfacng': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
二、使用 Matplotlib 注解绘制树形图
上节我们已经学习了如何从数据集中创建树,然而字典的表示形式非常不易于理解,而且直 接绘制图形也比较困难。本节我们将使用 Matplotlib
库创建树形图。决策树的主要优点就是直观 易于理解,如果不能将其直观地显示出来,就无法发挥其优势。虽然前面章节我们使用的图形库 已经非常强大,但是Python并没有提供绘制树的工具,因此我们必须自己绘制树形图。本节我们 将学习如何编写代码绘制如 图3 所示的决策树。
2.1 Matplotlib 注解
Matplotlib
提供了一个注解工具 annotations
,非常有用,它可以在数据图形上添加文本注 释。注解通常用于解释数据的内容。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支 持带箭头的划线工具,使得我们可以在其他恰当的地方指向数据位置,并在此处添加描述信息, 解释数据内容。如图4所示,在坐标 \((0.2, 0.1)\) 的位置有一个点,我们将对该点的描述信息放在 \((0.35, 0.3)\) 的位置,并用箭头指向数据点 \((0.2, 0.1)\)。
使用 Matplotlib
的注解功能绘制树形图,它可以对文字着色并提供多种形状以供选择, 而且我们还可以反转箭头,将它指向文本框而不是数据点。打开文本编辑器,创建名为 treePlotter.py
的新文件,然后输入下面的程序代码。
import matplotlib.pyplot as plt
# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.axl.annotate(nodeTxt,
xy=parentPt,
xycoords='axes fraction',
xytext=centerPt,
textcoords='axes fraction',
va="center",
ha="center",
bbox=nodeType,
arrowprops=arrow_args)
# createPlot 版本一
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf() # 清空绘图区
createPlot.axl = plt.subplot(111, frameon=False)
plotNode(U'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode(U'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
createPlot()
基于这个例子,现在开始学习绘制整棵树。
2.2 构造注解树
绘制一棵完整的树需要一些技巧。我们虽然有 \(x,y\) 坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定 \(x\) 轴的长度;我们还需要知道树有多少层,以便可以正确确定 \(y\) 轴的高度。这里我们定义两个新函数 getNumLeafs()
和 getTreeDepth()
,来 获取叶节点的数目和树的层数,参见下面程序,并将这两个函数添加到文件 treePlotter.py
中。
这段代码有与原书不一样之处,原因在于Python版本不同。主要是以下两个方面:
- 1.firstStr 的创建不同:具体问题请点击:(firstStr创建问题)
- if判断语句不同:具体问题请点击:(if判断语句不同)
# 获取叶节点个数
def getNumLeafs(myTree):
numLeafs = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0] # 找到输入的第一个元素
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# 获取树的层数
def getTreeDepth(myTree):
maxDepth = 0
firstSides = list(myTree.keys())
firstStr = firstSides[0] # 找到输入的第一个元素
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
上述程序中的两个函数具有相同的结构,后面我们也将使用到这两个函数。
这里使用的数据结构说明了如何在 Python
字典类型中存储树信息。第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值。从第一个关键字出发,我们可以遍历整棵树的所有子节点。 使用Python提供的type()函数可以判断子节点是否为字典类型 。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用 getNumLeafs()
函数。getNumLeafs()
函数遍历整棵树,累计叶子节点的个数,并返回该数值。第2个函数 getTreeDepth()
计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。为了节省大家的时间,函数 retrieveTree
输出预先存储的树信息,避 免了每次测试代码时都要从数据中创建树的麻烦。 添加下面的代码到文件 treePlotter.py
中:
#输出预先存储的树信息,避免每次测试代码都从数据中创建树的麻烦
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
]
return listOfTrees[i]
print('retrieveTree(0) : \n{}'.format(retrieveTree(0)))
print('retrieveTree(1) : \n{}'.format(retrieveTree(1)))
myTree = retrieveTree(0)
print('树的叶子结点个数为:\n{}'.format(getNumLeafs(myTree)))
print('树的深度为: \n{}'.format(getTreeDepth(myTree)))
2.3.构造注解树
#在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.axl.text(xMid, yMid, txtString)
#画一棵树
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) #计算树的宽
depth = getTreeDepth(myTree) #计算树的高
firstStr = list(myTree.keys())[0]
plotTree.totalW = float(getNumLeafs(myTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(myTree)) #存储树的深度
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
plotTree.yOff)
#cntrPt = (plotTree.xOff + (0.5/plotTree.totalW + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt) #标记子节点属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]) == dict:
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt,
leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
#createPlot 版本二
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axpropps = dict(xticks=[], yticks=[])
createPlot.axl = plt.subplot(111, frameon=False, **axpropps)
plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
plotTree.xOff = -0.5 / plotTree.totalW #xOff 与 yOff追踪已经绘制的节点位置以及下一个节点的恰当位置。
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
myTree = retrieveTree(0)
createPlot(myTree)
注:我在执行过程中发现,图像无法完全展示,所以我点击设置调整了图形大小及位置,调正后如下图。
2.4.变更字典
#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.axl.text(xMid,yMid,txtString)
#画一棵树
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree) #计算树的宽
depth = getTreeDepth(myTree) #计算树的高
firstStr = list(myTree.keys())[0]
plotTree.totalW = float(getNumLeafs(myTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(myTree)) #存储树的深度
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
#cntrPt = (plotTree.xOff + (0.5/plotTree.totalW + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt) #标记子节点属性值
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]) == dict:
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#createPlot 版本二
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axpropps = dict(xticks = [],yticks = [])
createPlot.axl = plt.subplot(111, frameon = False, **axpropps)
plotTree.totalW = float(getNumLeafs(inTree)) #存储树的宽度
plotTree.totalD = float(getTreeDepth(inTree)) #存储树的深度
plotTree.xOff = -0.5/plotTree.totalW #xOff 与 yOff追踪已经绘制的节点位置以及下一个节点的恰当位置。
plotTree.yOff = 1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
myTree = retrieveTree(0)
myTree['no surfacing'][3] = 'maybe'
print('myTree : \n{}'.format(myTree))
createPlot(myTree)
运行结果如下
myTree :
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
三、测试和存储分类器
3.1 测试算法:使用决策树进行分类
依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要决策树以及用于决策树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子结点;最后将测试数据定义为叶子结点所属的类型。
使用决策树分类的函数:添加进 trees.py
中
# 使用决策树
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
mydata, labels = createDataset()
mytree = createTree(mydata, labels)
print(classify(mytree, labels, [1, 1]))
# 程序报错ValueError: 'no surfacing' is not in list
# 因为createTree()函数中删除了最佳划分特征的标签 del(labels[bestFeat])
# 把 del(labels[bestFeat]) 注释掉便可以输出 yes
3.2 使用算法:决策树的存储
可以使用 Python
模块 pickle
序列化对象,参见下面的程序。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
四、示例:使用决策树预测隐形眼镜类型
隐形眼镜数据集市非常著名的数据集,它包含很多患者眼部状态的观察条件以及医生推荐的因性眼睛类型。隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。数据来源于UCI数据库,为了更容易显示数据,本书对数据做了简单的更改,数据存储在源代码下载路径的文本文件中。
# 实例:使用决策树预测隐形眼镜类型
f = open('D:\Coding\Py\Machine-Learning\Decision-Tree\lenses.txt')
lenses = [line.strip().split('\t') for line in f.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lenses)
print(lensesLabels)
print(lensesTree)
# 绘制决策树
import treePlotter
treePlotter.createPlot(lensesTree)
决策树很好地匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要的叶子结点。如果叶子结点只能增加少许信息,则可以删除该节点,将他并入到其他叶子结点中。第9章将进一步讨论这个问题。
第九章将学习另一个决策树构造算法CART,本章使用的算法成为ID3,它是一个好的算法但并不完美。ID3算法无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数据,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题。
附录:
-
关于基尼不纯度(Gini impurity)的更多信息,请参考Pan-Ning Tan, Vipin Kumar and Michael Steinbach, Introduction to Data Mineing. Pearson Eduction (Addison-Wesley, 2005), 158.
-
隐形眼镜数据集:The dataset is a modified version of the Lenses dataset retrieved from the UCI Machine Learning Repository November 3, 2001 [http://archive.ics.uci.edu/ml/machine-learning-databases/lenses/]. The source of the data is Jadzia Cendrowska and was originally published in “PRISM: An algorithm for inducing modular rules,” in International Journal of Man-Machine Studies (1987), 27, 349-70. 本书使用的数据的下载链接在:[链接]。