决策树 绘图

import matplotlib.pyplot as plt

# Box style for internal (decision) nodes; passed as ``bbox`` to annotate()
# by plotNode. A numeric-string facecolor is a matplotlib gray level and
# must lie in [0, 1] -- values outside that range are rejected.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# Box style for leaf (class label) nodes.
leafNode = dict(boxstyle="round4", fc="0.8")
# Arrow style shared by every parent->child edge drawn in plotNode.
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node at ``centerPt`` joined by an arrow to ``parentPt``.

    Both points are in axes-fraction coordinates. ``nodeType`` is the bbox
    style dict (``decisionNode`` or ``leafNode``). The target axes are read
    from ``createPlot.ax1``, so ``createPlot`` must have set it beforehand.
    """
    ax = createPlot.ax1
    ax.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )

def createPlot():
    """Demo: draw one sample decision node and one sample leaf node.

    NOTE: the later ``createPlot(inTree)`` definition in this file rebinds
    the name ``createPlot``, so once the module has finished loading this
    zero-argument version is no longer reachable under that name.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # plotNode() finds the axes through this function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    demo_nodes = (
        ('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode),
        ('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for label, center, parent, style in demo_nodes:
        plotNode(label, center, parent, style)
    plt.show()

def getNumLeafs(myTree):
    """Return the number of leaves in a nested-dict decision tree.

    A tree looks like ``{feature: {branch_value: subtree_or_label, ...}}``.
    Only the first top-level key is examined; each branch maps either to
    another such dict (a subtree, counted recursively) or to a non-dict
    class label (one leaf). An empty tree has no leaves.
    """
    if not myTree:
        return 0
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] is Python 2 only.
    branches = myTree[next(iter(myTree))]
    numLeafs = 0
    for child in branches.values():
        # isinstance also accepts dict subclasses (OrderedDict, defaultdict).
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    A tree whose branches are all leaves has depth 1; each nested subtree
    adds one level. Only the first top-level key is examined. An empty tree
    has depth 0.
    """
    if not myTree:
        return 0
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] is Python 2 only.
    branches = myTree[next(iter(myTree))]
    maxDepth = 0
    for child in branches.values():
        # isinstance also accepts dict subclasses (OrderedDict, defaultdict).
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth

def retrieveTree(i):
    """Return sample tree number ``i`` from a list of two hard-coded trees.

    Both samples split on 'no surfacing' and then 'flippers'; the second one
    adds a 'head' split under flippers == 0. ``i`` is used as a plain list
    index, so negative indices work and an out-of-range ``i`` raises
    IndexError.
    """
    sample_trees = [
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return sample_trees[i]

def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a child point and its parent point.

    Used to label the edge between two nodes (e.g. with the branch value).
    The midpoint is passed straight to ``createPlot.ax1.text``.
    """
    half_dx = (parentPt[0] - cntrPt[0]) / 2.0
    half_dy = (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(cntrPt[0] + half_dx, cntrPt[1] + half_dy, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    Layout state lives on function attributes that ``createPlot`` sets up:
    ``plotTree.totalW`` (leaf count of the whole tree) and ``plotTree.totalD``
    (its depth) fix the horizontal/vertical spacing, while ``plotTree.xOff``
    (x of the most recently placed leaf) and ``plotTree.yOff`` (y of the
    current level) are cursors updated as nodes are drawn.

    Args:
        myTree: nested-dict (sub)tree; its first key is the split feature.
        parentPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: label for the edge from the parent ('' for the root).
    """
    numLeafs = getNumLeafs(myTree)
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] is Python 2 only.
    firstStr = next(iter(myTree))
    # Centre this decision node horizontally over the leaf slots its
    # subtree will occupy (the next numLeafs slots after xOff).
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw this node
    secondDict = myTree[firstStr]
    # Move down one level for the children...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Not a leaf: recurse into the subtree.
            plotTree(child, cntrPt, str(key))
        else:
            # A leaf: advance the x cursor one slot and draw the label there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    # ...and move back up once all children are drawn.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    """Render the nested-dict decision tree ``inTree`` in a matplotlib figure.

    Sets up the axes and the layout state that ``plotTree`` reads and
    updates, draws the tree with its root at the top centre, and shows it.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # plotNode()/plotMidText() draw onto these axes via the function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Leaves get equal-width slots of 1/totalW across [0, 1]. Starting the
    # x cursor half a slot left of 0 means that the first leaf, after
    # plotTree advances the cursor by 1/totalW, lands at the centre of slot 0.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
View Code

 

posted on 2018-03-09 13:55  HelloWorld!--By-MJY  阅读(177)  评论(0)  编辑  收藏  举报

导航