您的位置:首页 > 编程语言 > Python开发

Python学习-机器学习实战-ch03 Decision Tree_Part2

2016-03-27 22:38 821 查看
前一篇文章主要讲述了第三章的前部分,主要是决策树的算法实现部分

这一部分,因为篇幅较长,我们单独将他们分出来。包括:决策树可视化和分类预测

=======================================================================

matplotlib的文本注释绘制树节点

使用了Matplotlib中的annotate工具

import matplotlib.pyplot as plt

decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
#设置三种类型

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
#四个参数应该对应的的是:节点文本、中心坐标、父节点坐标、节点类型
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,\
textcoords='axes fraction',va="center",ha="center",bbox=nodeType,\
arrowprops=arrow_args)

def createPlot():
fig=plt.figure(1,facecolor='white')
fig.clf()
createPlot.ax1=plt.subplot(111,frameon=False)
plotNode(U'决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
plotNode(U'叶节点',(0.8,0.1),(0.3,0.8),leafNode)
#绘制两个节点
plt.show()




节点文本没有显示出来

获取叶子节点

方法与学习数据结构中的二叉树叶子节点的方法一样。只是决策树内一个节点有很多个分支
def getNumLeafs(myTree):
numLeafs=0
firstStr=list(myTree.keys())[0]
#对于本例中的存储方式,key如果是值格式则为叶子节点,若是字典则为某子树的根节点
#可以以递归的方式求二叉树叶节点
secondDict=myTree[firstStr]
for key in list(secondDict.keys()):
if isinstance(secondDict[key],dict):
numLeafs+=getNumLeafs(secondDict[key])
#如果不是叶节点,则以该节点为根节点的子树进入循环
else:
numLeafs+=1
return numLeafs
在此有个地方跟书中不一样:myTree.keys()[0]返回的并不是列表行的不能有Index,所以改为了list(myTree.keys())[0]

获取树的层数

def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in list(secondDict.keys()):
if isinstance(secondDict[key],dict):
thisDepth=1+getTreeDepth(secondDict[key])
#对于非叶子节点,进一步迭代求其深度
else:
thisDepth=1
#叶子节点
if thisDepth>maxDepth:maxDepth=thisDepth
#遍历该层的每个节点,选出该层节点中的最大层数
return maxDepth


还是跟二叉树的方法一样,只是二叉树中返回的是左右子树的最大层数,此处决策树并不是二叉树而是多叉树,所以是用maxDepth动态选出最大层数

def retrieveTree(i):
listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
return listOfTrees[i]


生成一棵现成的树,不用每次试验都构建

def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
numLeafs=getNumLeafs(myTree)
depth=getTreeDepth(myTree)
firstStr=list(myTree.keys())[0]
cntrPt=(plotTree.xoff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yoff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr]
plotTree.yoff=plotTree.yoff-1.0/plotTree.totalD
for key in secondDict.keys():
if isinstance(secondDict[key],dict):
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xoff=plotTree.xoff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff),cntrPt,leafNode)
plotMidText((plotTree.xoff,plotTree.yoff),cntrPt,str(key))
plotTree.yoff=plotTree.yoff+1.0/plotTree.totalD

def createPlot(inTree):
#主函数,调用了plotTree
fig=plt.figure(1,facecolor='white')
fig.clf()
axprops=dict(xticks=[],yticks=[])
createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
plotTree.totalW=float(getNumLeafs(inTree))
plotTree.totalD=float(getTreeDepth(inTree))
plotTree.xoff=-0.5/plotTree.totalW
plotTree.yoff=1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()


是之前画树小例子的升级版。其中type(***).__name__我用的时候一直报错
我以为 在python 3.X版中type与原来的方法 不一样了。
所以改成:isinstance(secondDict[key],dict)

结果后面发现是.__name__被我打成了._name_ ……
是两个下划线不是一个下划线!!!



使用决策树做分类:
def classify(inputTree,featLabels,testVec):
#分类函数输入的参数有:以构建好的决策树、每个特征维对应的标签、用于分类的向量
firstStr=list(inputTree.keys())[0]
secondDict=inputTree[firstStr]
featIndex=featLabels.index(firstStr)
#由特征标签获得其对应的下标,用于取测试向量中的对应值
for key in secondDict.keys():
if testVec[featIndex]==key:
#如果当前分支与测试向量对应维度值相同
if type(secondDict[key]).__name__=='dict':
#如果该分支对应还有子树,则递归进入下一层子树
classLabel=classify(secondDict[key],featLabels,testVec)
else:
#如果对应分支为叶子节点,则所对应的值为分类结果
classLabel=secondDict[key]
return classLabel


测试代码:
import trees
import treePlotter
myDAT,labels=trees.createDataset()
myTree=treePlotter.retrieveTree(0)
print(trees.classify(myTree,labels,[1,0]))
#该例中输入的样本特征为[1,0]



存储决策树:

def storeTree(inputTree,filename):
import pickle
fw=open(filename,'wb')
pickle.dump(inputTree,fw)
fw.close()

def grabTree(filename):
import pickle
fr=open(filename,'rb')
return pickle.load(fr)


此处修改两个地方:
fw=open(filename,'wb')


fr=open(filename,'rb')

将其以Byte的形式存储

用决策树预测隐形眼镜类型:

fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
#读取文本数据并进行处理
lensesLabel=['age','prescript','astigmatic','tearRate']
#该数据集有四个特征维度
lensesTree=trees.createTree(lenses,lensesLabel)
#建立决策树
print(lensesTree)
treePlotter.createPlot(lensesTree)
#画出决策树




============================================================================================================

决策树这一章终于结束啦,后面的内容会越来越难,坚持!加油!

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: