意思決定ツリー-機械学習実戦完全版(python 3)
3697 ワード
import matplotlib.pyplot as plt
# Node box styles: `boxstyle` picks the box shape (e.g. "sawtooth"), `fc` sets its face colour.
'''xy
xytext
: , ?
:arrowstyle="maxDepth: maxDepth=thisDepth
return maxDepth
def retrieveTree(i):
    """Return one of two hard-coded sample decision trees.

    Each tree is a nested dict: an internal node maps a feature name to a
    dict from feature value to either a class label (leaf) or a subtree.
    A fresh copy is built on every call, so callers may mutate the result.

    :param i: index into the two-element list of samples (0 or 1; negative
        indices work as for a list, out-of-range values raise IndexError).
    """
    simple = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    deeper = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}},
        }
    }
    samples = [simple, deeper]
    return samples[i]
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString halfway along the edge between parentPt and cntrPt.

    The tree-drawing code uses this to label the edge from a parent node to
    a child with the feature value that leads to that child.

    :param cntrPt: (x, y) position of the child node.
    :param parentPt: (x, y) position of the parent node.
    :param txtString: text drawn at the midpoint on createPlot.ax1.
    """
    def halfway(a, b):
        # Keep the exact arithmetic order (a - b) / 2 + b.
        return (a - b) / 2.0 + b

    midpoint = (halfway(parentPt[0], cntrPt[0]), halfway(parentPt[1], cntrPt[1]))
    createPlot.ax1.text(*midpoint, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree, linking its root node to parentPt.

    Layout state is kept on function attributes initialised by createPlot:
    plotTree.totalW (getNumLeafs of the whole tree), plotTree.totalD
    (getTreeDepth of the whole tree), plotTree.xOff (x of the most recently
    drawn leaf) and plotTree.yOff (y of the level currently being drawn).
    All coordinates are axes fractions in [0, 1].

    :param myTree: nested-dict tree whose first key is the feature tested at
        this node; it maps feature values to subtrees (dicts) or leaf labels.
    :param parentPt: (x, y) of the parent node the edge is drawn from.
    :param nodeTxt: label for the edge from the parent (the feature value).
    """
    numLeafs = getNumLeafs(myTree)
    # The previous version also called getTreeDepth(myTree) here and never
    # used the result, re-walking the subtree at every node (quadratic work).
    firstStr = list(myTree)[0]
    # Leaves sit in equal slots of width 1/totalW, and xOff is the centre of
    # the last slot already used. This subtree's leaves will fill the next
    # numLeafs slots (centres xOff + 1/totalW .. xOff + numLeafs/totalW), so
    # the subtree root goes at their midpoint: xOff + (1 + numLeafs)/2/totalW.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    # NOTE(review): decisionNode / leafNode / plotNode are defined in a part of
    # the file not visible here; presumably plotNode draws a boxed label with
    # an arrow from parentPt.
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level for this node's children...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # A leaf takes the next horizontal slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ...and restore the level so sibling subtrees are drawn at this depth.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the decision tree inTree in a matplotlib window and show it.

    Creates the shared axes (stored as createPlot.ax1, which plotMidText and
    presumably plotNode draw on), initialises plotTree's layout state and
    draws the whole tree with its root at the top centre of the axes.

    :param inTree: nested-dict decision tree (see retrieveTree for samples).
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # figure 1 is reused, so clear anything from a previous call
    # Hide tick marks on both axes; they carry no meaning for a tree drawing.
    # (The old code misspelled 'yticks' as 'ytichs' and so left these unused.)
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Leaves are centred in slots of width 1/totalW: 0.5/totalW, 1.5/totalW, ...
    # Starting half a slot left of 0 lets plotTree add 1/totalW before each leaf.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0  # root level at the top edge; each level is 1/totalD lower
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()