您的位置:首页 > 编程语言 > Python开发

【ML学习笔记】23:用python绘制决策树

2018-02-02 21:06 363 查看
继续跟着白皮书学习,对上面的代码做了不少改动,现在能正确绘制了。

先不谈决策树的算法,现在仅仅是依据字典表示树来绘制决策树的图形。

go.py

引导脚本。

#!/usr/local/bin/python3.5
import treePlot
myTree0=treePlot.getTstTree(0)
myTree1=treePlot.getTstTree(1)
myTree0['no surfacing'][1]['flippers'][0]=myTree1['no surfacing'][0]
treePlot.mainPlot(myTree0)


treePlot.py

#!/usr/local/bin/python3.5
#-*-coding:utf-8-*-
import matplotlib.pyplot as plt

#建立存决策结点格式的字典{'fc': '0.8', 'boxstyle': 'sawtooth'}
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
#建立存叶结点格式的字典{'fc': '0.8', 'boxstyle': 'round4'}
leafNode=dict(boxstyle="round4",fc="0.8")
#建立存箭头格式的字典{'arrowstyle': '->'}
arrow_args=dict(arrowstyle='-')

#绘制结点(结点名称,结点位置,箭头起点,结点类型)
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
#下面的pyplot.annotate()用于做文本注释
#参数s:传注释的文本字符串nodeTxt
#参数xy:传被注释的坐标元组(x,y)
#参数xytext:传插入文本的坐标元组(x,y),如果和xy不一样,会产生箭头
#参数xycoords:指定传入的参数xy所依据的坐标系统
#这个参数的值取'axes fraction'时表示从左下角的坐标轴
#参数textcoords:指定传入的参数xytext所依据的坐标系统,规则同xycoords
#参数arrowprops:传入一个字典,如果字典中有键为arrowstyle的键值对,那么其对应的值可以指定箭头的类型
#这个参数中的键arrowstyle的值还可以取'-|>','-['等..用的时候查官方文档吧
createPlot.ax1.annotate(\
s=nodeTxt,\
xy=parentPt,\
xytext=centerPt,\
xycoords='axes fraction',\
textcoords='axes fraction',\
va="center",\
ha="center",\
bbox=nodeType,\
arrowprops=arrow_args)

#用来测试pyplot.annotate()绘制注释的函数
def createPlot():
#下面的pyplot.annotate()用于创建一个新绘图对象
#参数num:若不提供则创建一个新图形;若提供了存在的num值则返回其引用;否则创建它并在窗口标题上显示
#提供的num为数字会显示'Figure 数字';字符串会直接显示这个字符串
#参数facecolor:指定背景颜色,可以使用颜色名或16进制颜色
fig=plt.figure(num='绘制注释',facecolor='#99CC66')
fig.clf() #清除figure对象fig上的图形
#这里createPlot.ax1是对createPlot这个函数定义了一个属性ax1
#python可以用这种方式来实现全局变量
#在这个属性中,用plt.subplot创建子图并获取了这个子图的引用
#就可以在其它函数中通过访问该函数这个属性直接操作这个子图了
#frameon指定子图是否独立出来,默认是True
#子图不独立出来时,将继承figure对象的facecolor
createPlot.ax1=plt.subplot(111,frameon=True)
#调用绘制结点的函数,在函数体内用createPlot.ax1访问到此处建立的子图
plotNode('Decision node',(0.5,0.1),(0.1,0.5),decisionNode)
plotNode('Leaf nodes',(0.8,0.1),(0.3,0.8),leafNode)
plt.show()

#用来测试的已经建立好的字典形式的树
def getTstTree(i):
#只提供了两棵树
treeList=[\
{'no surfacing':\
{0:'no',1:\
{'flippers':\
{0:'no',1:'yes'}\
}\
}\
},\
{'no surfacing':\
{0:\
{'head':\
{0:'no',1:'yes'}\
},\
1:'no'\
}\
}\
]
return treeList[i]

#获取树myTree的叶结点数目
def getNumLeafs(myTree):
numLeafs=0 #初始化叶子结点数目为0
firstStr=list(myTree.keys())[0] #获取当前子树的树根key
secondDict=myTree[firstStr] #获取对应的所有可能划分的子树字典
for key in secondDict.keys(): #对于其中的每一个划分出的子树
#如果这棵子树下还有树,即其对应的value值还是一个字典对象
if type(secondDict[key]).__name__=='dict':
#将这个字典对象传入,递归调用求其叶结点数目加到总数中
numLeafs+=getNumLeafs(secondDict[key])
else: #如果这棵子树已经是叶结点了,即不再包含字典了
numLeafs+=1 #递归出口,记录叶结点数增加了1
return numLeafs #返回这个树下总的叶结点数目

#获取树myTree的高度
def getTreeDepth(myTree):
maxDepth=0 #初始化树的高度为0
firstStr=list(myTree.keys())[0] #获取当前子树的树根key
secondDict=myTree[firstStr] #获取对应的所有可能划分的子树字典
for key in secondDict.keys(): #对于其中的每一个划分出的子树
#如果这棵子树下还有树,即其对应的value值还是一个字典对象
if type(secondDict[key]).__name__=='dict':
#将这个字典对象传入,递归调用求其树高,加上子树根的高度1
thisDepth=getTreeDepth(secondDict[key])+1
else: #如果这棵子树已经是叶结点了,即不再包含字典了
thisDepth=1 #递归出口,记录单一结点的树高是1
if thisDepth>maxDepth: #如果这次找出的树高更高了
maxDepth=thisDepth #更新最高值
return maxDepth #返回这个树的树高

#在树的父子结点之间填充文本信息
#(子结点坐标[x,y],父结点坐标[x,y],文本信息字符串)
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]+cntrPt[0])/2.0 #横坐标中心
yMid=(parentPt[1]+cntrPt[1])/2.0 #纵坐标中心
#利用mainPlot函数的属性ax1在subplot对象上添加文本
mainPlot.ax1.text(xMid,yMid,txtString)

#在mainPlot函数的属性ax1对应的subplot对象上绘制结点
#(结点名称,结点位置,箭头起点,结点类型)
def mainPlotNode(nodeTxt,centerPt,parentPt,nodeType):
#下面的pyplot.annotate()用于做文本注释
#参数s:传注释的文本字符串nodeTxt
#参数xy:传被注释的坐标元组(x,y)
#参数xytext:传插入文本的坐标元组(x,y),如果和xy不一样,会产生箭头
#参数xycoords:指定传入的参数xy所依据的坐标系统
#这个参数的值取'axes fraction'时表示从左下角的坐标轴
#参数textcoords:指定传入的参数xytext所依据的坐标系统,规则同xycoords
#参数arrowprops:传入一个字典,如果字典中有键为arrowstyle的键值对,那么其对应的值可以指定箭头的类型
#这个参数中的键arrowstyle的值还可以取'-|>','-['等..用的时候查官方文档吧
mainPlot.ax1.annotate(\
s=nodeTxt,\
xy=parentPt,\
xytext=centerPt,\
xycoords='axes fraction',\
textcoords='axes fraction',\
va="center",\
ha="center",\
bbox=nodeType,\
arrowprops=arrow_args)

#绘制决策(子)树,也是一个递归的函数
#(字典表示树,父结点坐标[x,y],填充的文本信息)
def plotTree(myTree,parentPt,txtString):
numLeafs=getNumLeafs(myTree) #计算叶结点数(表征子树宽)
depth=getTreeDepth(myTree) #计算子树的高度
firstStr=list(myTree.keys())[0] #获取当前子树的树根key
#按比例计算树当前子树根结点的摆放位置
cntrPt=(plotTree.xOff+\
(1.0+float(numLeafs))/2.0/plotTree.totalW,\
plotTree.yOff)
#在树的父子结点之间填充文本信息
#(子结点坐标=当前根结点坐标,父结点坐标,填充的文本信息)
plotMidText(cntrPt,parentPt,txtString)
#绘制决策结点(结点名称,结点位置,箭头起点=父结点,结点类型=决策结点)
mainPlotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr] #获取对应的所有可能划分的子树字典
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD #为了绘制子树,y轴偏移量按比例减少
for key in secondDict.keys(): #对于其中的每一个划分出的子树
#如果这棵子树下还有树,即其对应的value值还是一个字典对象
if type(secondDict[key]).__name__=='dict':
#将这个字典对象传入,递归调用绘制其子树
#(子树下的字典对象,子树父结点坐标=当前根结点坐标,子树文本)
plotTree(secondDict[key],cntrPt,str(key))
#为了不影响兄弟结点的高度(yOff)
#在一个结点上的递归绘制子树完成返回之前
#需要把本层按比例减掉的yOff(全局纵坐标值)按比例加回来
#plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
else: #如果这棵子树已经是叶结点了,即不再包含字典了
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW #对子树,x轴偏移量按比例增加
#绘制新的结点(子树文本,结点位置=当前xy平移后的位置,箭头起点=当前树根结点位置,结点类型=叶子结点)
mainPlotNode(secondDict[key],\
(plotTree.xOff,plotTree.yOff),\
cntrPt,\
leafNode)
#在父子结点间填充文本信息.在绘制树时,这一步在plotTree里
#但在这里是绘制叶结点,没有封装进函数里而是单独拿出来做
#(子结点坐标=当前偏移后位置,父结点坐标=当前根结点位置,子树文本)
plotMidText((plotTree.xOff,plotTree.yOff),\
cntrPt,
str(key))

#绘制字典表示树inTree
def mainPlot(inTree):
fig=plt.figure('绘制决策树',facecolor='#CCCCFF')
fig.clf()
axprops=dict(xticks=[],yticks=[])
mainPlot.ax1=plt.subplot(111,frameon=False,**axprops)
#plotTree函数的totalW属性存储树的总宽度
plotTree.totalW=float(getNumLeafs(inTree))
#plotTree函数的totalD属性存储树的总高度
plotTree.totalD=float(getTreeDepth(inTree))
#初始化其xOff和yOff位置
#以使树inTree的根节点位置尽量合适
plotTree.xOff=-0.5/plotTree.totalW
plotTree.yOff=1.0
#在合适的位置选取虚拟的父节点
#以使树inTree的根节点位置尽量合适
plotTree(inTree,(0.5,1.0),'')
plt.show()


运行结果

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: