【机器学习】决策树-02
心得体会:
1. 用嵌套字典(dict)表示决策树,并用matplotlib的注解功能绘制树形图
2. 决策树可以用pickle以二进制模式‘wb’序列化保存到文件,再以‘rb’模式读取还原(文件扩展名即使是.txt,内容也是二进制数据)
# 3.2 Drawing tree diagrams with Matplotlib annotations
# Tree nodes are drawn as text annotations.
import matplotlib
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # box style for internal (decision) nodes
leafNode = dict(boxstyle="round4", fc="0.8")  # box style for leaf nodes
arrow_args = dict(arrowstyle="<-")  # arrow style for parent -> child edges


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node at centerPt with an arrow coming from parentPt.

    Args:
        nodeTxt: text shown inside the node box.
        centerPt: (x, y) of the node, in axes-fraction coordinates.
        parentPt: (x, y) the arrow starts from, in axes-fraction coordinates.
        nodeType: bbox style dict, e.g. ``decisionNode`` or ``leafNode``.

    Requires ``createPlot.ax1`` to have been set by ``createPlot``.
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords="axes fraction",
        xytext=centerPt,
        textcoords="axes fraction",
        va="center",
        # Layout code computes centerPt as the node's centre, so centre the text
        # horizontally too (the matplotlib default is left-aligned).
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )


# Early demo of plotNode, kept for reference:
# def createPlot():
#     fig = plt.figure(1, facecolor='white')  # figure number 1, white background
#     fig.clf()  # clear all axes but keep the window open for reuse
#     createPlot.ax1 = plt.subplot(111, frameon=False)  # 1x1 grid, subplot 1
#     plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('叶节点', (0.8, 0.1), (0.0, 0.0), leafNode)
#     plt.show()
# # Note: avoid the QQ input method while matplotlib windows are active.
# createPlot()


# Building the annotated tree
def getNumLeafs(myTree):
    """Return the number of leaves in a nested-dict decision tree.

    The tree has the form {featureLabel: {featureValue: child, ...}}, where
    each child is either another such dict (a subtree) or a leaf value.
    """
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    A root whose children are all leaves has depth 1.
    """
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def retrieveTree():
    """Build and return the sample tree from createDataSet()/createTree().

    NOTE(review): createDataSet and createTree are not defined in this chunk;
    presumably they come from the companion decision-tree module -- confirm.
    """
    myDat, labels = createDataSet()
    return createTree(myDat, labels)


# mytree = retrieveTree()
# print(getNumLeafs(mytree))
# print(getTreeDepth(mytree))


def plotMidTest(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge between parentPt and cntrPt."""
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree below parentPt, labelling the incoming edge with nodeTxt.

    Layout state is kept in function attributes set by createPlot:
    totalW (total leaf count), totalD (tree depth), and the cursor
    xOff / yOff (axes-fraction coordinates).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]
    # Centre this node horizontally over the span occupied by its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidTest(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level for the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf advances the x cursor by one leaf slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidTest((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the level for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Draw inTree in a new (cleared) figure 1 and show it."""
    fig = plt.figure(1, facecolor='white')  # white-background figure
    fig.clf()  # clear any previous drawing
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # tick-less, frameless axes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot to the left so the first leaf lands at 0.5/totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# createPlot(retrieveTree())


# 3-3 Testing and storing the classifier
def classify(inputTree, featLabels, testVec):
    """Return the class label the tree assigns to testVec.

    Args:
        inputTree: nested-dict decision tree.
        featLabels: feature names, in the same order as the values in testVec.
        testVec: the value of each feature in featLabels.

    Raises:
        ValueError: if a feature in the tree is missing from featLabels, or if
            testVec's value for a feature has no matching branch.
    """
    firstStr = list(inputTree)[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, child in secondDict.items():
        if featValue == key:
            if isinstance(child, dict):
                return classify(child, featLabels, testVec)
            return child
    # Previously this fell through to an UnboundLocalError.
    raise ValueError(f"no branch for feature {firstStr!r} with value {featValue!r}")


# Using the algorithm: persisting the decision tree
def storeTree(inputTree, filename):
    """Serialize inputTree to filename with pickle (binary mode)."""
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously saved with storeTree.

    SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


# myTree = retrieveTree()
# storeTree(myTree, "E:/Python/PycharmProjects/机器学习实战/Include/第03章_决策树/s.txt")
# print(grabTree("E:/Python/PycharmProjects/机器学习实战/Include/第03章_决策树/s.txt"))
# Example: predicting contact-lens type with a decision tree.
# NOTE(review): the dataset path is hard-coded to the author's machine.
# Use a context manager so the file handle is closed after reading.
with open("E:/Python/《机器学习实战》代码/Ch03/lenses.txt") as fr:
    # One sample per line, tab-separated fields; presumably the last field is
    # the class label expected by createTree -- confirm.
    lenses = [data.strip().split('\t') for data in fr]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
createPlot(lensesTree)