您的位置:首页 > 编程语言 > Python开发

决策树--ID3

2016-07-06 16:57 357 查看
机器学习实战第三章学习笔记使用决策树分类的一般步骤:准备数据à构建决策树(该决策树可以被存成文件,便于二次利用)à使用决策树决策【准备数据】有n组数据,每一组有m个元素,前m-1个元素为分类标签即特征,最后一个元素为决策结果。【构建决策树】首先选择最优的分类特征。输入:数据集dataSet,假设该数据集有k个特征,1个决策结果。方法:从第一个特征到第k个特征,每次使用一个特征去划分数据集dataSet成dataSet1、dataSet2、dataSet3…。挨个计算每一个子数据集的熵值n1、n2、n3…,并求和(dataSet1/dataset)*n1+(dataSet2/dataset)*n2+(dataSet3/dataset)*n3+….。总的熵值最小的那个特征便是最优分类特征。ps:所有代码的依赖包from math import logimport operatorimport matplotlib.pyplot as pltimport numpy[具体代码如下]
########
##计算输入数据集的熵值
########
def calcShannonEnt(dataSet):   #####计算数据集dataSet的熵
numEntries = len(dataSet)   ###统计数据集的行数
labelCounts = {}
for featVec in dataSet :
currentLabel = featVec[-1]    ####提取数据集中最后一列的内容,即数据标签。除去数据标签外的每一列都是该行数据的一个特征
if currentLabel not in labelCounts.keys():  ###统计每一种标签出现的次数,并将其存储在一个字典中格式为:{标签名, 出现次数}
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1

shannonEnt = 0.0
for key in labelCounts:  ###计算当前数据集的熵
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
使用key-value格式的数据结构创建决策树。输入:数据集dataSet,  标签序列方法:从数据集dataSet中选择最优分类特征feature1。以该特征值为key创建空的键值对。通过在数据集中计算该feature1所在列有多少个取值,就可以知道以该feature1为父结点的子节点有多少个。假设该feature1有i个子节点值分别为temp1、temp2、temp3…tempi,则循环i次,每次从当前数据集dataSet中删除feature1所在列值为tmpi行,并以余下的数据作为一个新的数据集创建子树。【递归调用】[具体代码如下]
########
##从数据集dataSet中删除第axis列,值为value的那行,并返回剩余的数据集
########
def splitDateSet(dataSet, axis, value):     ###从数据集合dataSet中选出第axis+1列内容等于value的数据行,且选出的数据不再包含第axis+1列的内容
retDataSet = []
for featVec in dataSet :
if featVec[axis] == value:
reduceFeatVec = featVec[:axis]   #featVec[:axis]选去featVec中从第0到第axis列的内容
reduceFeatVec.extend(featVec[axis+1:])    #featVec[axis+1:]选去featVec中第axis+1列之后的内容
retDataSet.append(reduceFeatVec)
return retDataSet
########
##从输入的数据集中挑选取分类效果最好的特征
########
def chooseBestFeatureToSplit(dataSet):  ##选择最优的数据划分方式
numFeatures = len(dataSet[0]) - 1    ###计算每组数据中除去最后一个数据类型标签标示外总共有多少个特征,这些特征描述该组数据
baseEntropy = calcShannonEnt(dataSet)   #判断数据集合dataSet之中是否有多种类型的数据,即判断数据集的熵值是否为零
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):      #从第1列开始(即第一个特征,每一行各有一个特征值),每次选取一个特征去划分数据集合
featList =  [example[i] for example in dataSet]         ###i表示特征所在列的索引,featList存放该列的每一个具体的值
uniqueVals = set(featList)                      ####删除重复的值
newEntorpy = 0.0
for value in uniqueVals:
subDataSet = splitDateSet(dataSet, i, value)    ####从dataSet中筛选出第i列值为value的那些行,并把这些行划分为同一类
prob = len(subDataSet)/float(len(dataSet))     ###计算当前划分出的子数据集合所拥有的数据行数占总数据行的比值
newEntorpy += prob * calcShannonEnt(subDataSet)    ####用上一步的比值乘以子数据集合的熵,newEntorpy中存储每个子数据集合的熵的累加值
infoGain = baseEntropy - newEntorpy            ##由于熵的值为负数,为了比较大小,取整
if infoGain > bestInfoGain :           ####上面的步骤是从第i列下手,计算在此列中划分的各子数据集熵值的累加和
bestInfoGain = infoGain           ####使用冠军法选取最小的那个熵值下的特征列
bestFeature = i  ##返回最优分类特征的索引值
return bestFeature

<pre name="code" class="python">########
##定义这个方法的主要作用是用于处理特殊情况:当数据集已经处理了所有的属性,但是类标签依然不是唯一的,此时我们需要决定如何定义该叶节点,在这种情况下,我们通常会采用多数表决的方法决定该叶子结点的分类
########
def majorityCnt(classList): ###决定叶子节点的分类 classCount={} ###存储classList中每个类标签出现的频率 for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True) ##operator模块提供itemgetter函数用于获取对象的哪些维的数据或者哪些key对应的数据,参数就是索引号或key值.可以设置多个索引号或key值。要注意,operator.itemgetter函数获取的不是值,而是定义了一个函数,通过该函数作用到对象上才能获取值 ##operator.itemgetter(1)获取索引号为1的对象的内容 return sortedClassCount##########创建决策树########def createTree(dataSet,labels): classList = [ example[-1] for example in dataSet] ##获得数据的决策结果 if classList.count(classList[0] == len(classList)): ##如果当前数据集只有一组数据,结束返回 return classList[0] if len(dataSet[0]) == 1: ##当数据集中只有一组数据时说明此组数据一定是叶子节点 return majorityCnt(classList) bestFeat= chooseBestFeatureToSplit(dataSet) ##从当前数据集中选择出最优的分类特征的索引号 bestFeatLabel = labels[bestFeat] ##根据索引号提取具体的特征名 myTree = {bestFeatLabel:{}} ##根据获得的特征值创建树节点 del (labels[bestFeat]) ##将已经使用的特征从特征序列中删除 featValues = [example[bestFeat] for example in dataSet] ##获取该特征下的各种取值,当该特征作为父节点时,其不同的取值就是不同的子节点uniqueVals = set(featValues) ##去除重复的值,将不同的值存在一个临时的list中,该list的长度即子节点的数目 for value in uniqueVals: subLabels = labels[:] ##将去除已用特征后剩余的特征作为一个子标签序列 print bestFeat print value #splitDateSet(dataSet, bestFeat, value)从当前所用的数据集中删除第bestFeat列,值为value的那行数据,并返回余下的数据集myTree[bestFeatLabel][value] = createTree(splitDateSet(dataSet, bestFeat, value), subLabels) ##以当前所用特征为key,并将子树作为其value值 return myTree
  数据集的香农熵:利用计算公式计算,熵的值越大数据集中数据类型越多,熵值为零说明该数据集中就只有一个一类数据。【利用熵值来判断当前数据集中的数据类型是否单一,如果不单一则继续划分】 决策树的特点:标签都是非叶子结点且都有两个子结点,在python中以该标签为key的value值都是一个字典(dict);决策结果都是叶子结点。 决策树的实际存储模型:【利用决策树判断】假如使用的决策树为:决策树中的标签为:no surfacing,  flippers待判断的输入为:[1, 0]    【说明:输入[1,0]表示nosurfacing的value值为1,flippers的value值为0 ;其它输入依次类推断】在实际写代码使用决策树进行决策时一般会有三个输入:“树模型”、“用到的标签”、“标签符合情况”。其中标签符合情况中的元素个数与用到的标签的个数相同,且从前到后依次匹配。例如:用到的标签[a, b, c, d, e, f ...];标签符合情况[1, 0 , 0 , 0 , 1, 0 …]就表示与a:1;b:0;c:0;d:0;e:1;f:0。代码实现的过程中我们会从决策树的根结点开始,根据决策树上的实际标签值到“用到的标签”中找出该标签,之后再到“标签符合情况”中找到具体的符合情况。[具体代码如下]
########
##使用决策树进行分类
########
def classify(inputTree,featLabels,testVec):  #inputTree:决策树; featLabels:进行决策所用到的标签(该标签与决策树中的非叶子节点中的key相同); testVec:标签实际符合情况
firstStr = inputTree.keys()[0]  #获得根结点处的标签(key)
secondDict = inputTree[firstStr]  #获得根结点处的标签符合情况(value)
featIndex = featLabels.index(firstStr)  #查找当前标签(key)在实际输入标签中的位置
print featIndex
key = testVec[featIndex]  #根据实际标签,查找当前标签的实际符合情况
valueOfFeat = secondDict[key]  #根据符合情况获得子节点的内容,可能是决策结果也可能是一个以标签为key的字典(dict)
if isinstance(valueOfFeat, dict):  #判断子节点是叶子节点还是非叶子节点
classLabel = classify(valueOfFeat, featLabels, testVec)  #是一个dict继续决策
else: classLabel = valueOfFeat  #是个决策结果,返回该结果
return classLabel
-----------------------------------------------------------------------------------------------【画出具体的树型图】-----------------------------------------------------------------------------------
[假设输入数据为]文件名:lenses.txt
young<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
young<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>soft
young<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
young<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>hard
young<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
young<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>soft
young<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
young<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>hard
pre<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
pre<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>soft
pre<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
pre<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>hard
pre<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
pre<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>soft
pre<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
pre<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>no lenses
presbyopic<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
presbyopic<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>no lenses
presbyopic<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
presbyopic<span style="white-space:pre">	</span>myope<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>hard
presbyopic<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
presbyopic<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>no<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>soft
presbyopic<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>reduced<span style="white-space:pre">	</span>no lenses
presbyopic<span style="white-space:pre">	</span>hyper<span style="white-space:pre">	</span>yes<span style="white-space:pre">	</span>normal<span style="white-space:pre">	</span>no lenses
def plotNode(nodeTxt, centerPt, parentPt, nodeType):         ###执行实际的画图工作createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext = centerPt, textcoords= 'axes fraction',va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)## def createPlot():                     ###获得画图的地方#     fig = plt.figure(1, facecolor='white')  ###matplotlib.pyplot.figure    Creates a new figure.  num:an id for figure; facecolor:the background color#     fig.clf()  ###fig的类型是matplotlib.figure.Figure    fig.clf():Clear the figure#     createPlot.ax1 = plt.subplot(111, frameon=False)  ##subplot(m,n,p)==subplot(mnp) if m,n,p<10,作用是将多个figure放到一个平面中。其中m,n表示将一个figure切割成m行n列。p表示序号即第几块#     print type(createPlot.ax1)#     ### print type(createPlot.ax1)   打印的内容是<class 'matplotlib.axes.AxesSubplot'>#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)  ###执行实际的绘图工作#     plotNode('a left node', (0.8, 0.1), (0.3, 0.8), leafNode)#     plt.show()##########获得决策树的叶子节点个数########def getNumleafs(myTree):              ###获得树的叶子节点个数numLeafs = 0firstStr = myTree.keys()[0]     ###myTree.keys()的内容虽然是一个字符串,但是类型却是list;myTree.keys()[0]获取list中第一个索引代表的内容secondDict = myTree[firstStr]    ###secondDict是一个字典类型,secondDict.key()提取字典secondDict中所有的keyfor key in secondDict.keys():if type(secondDict[key]).__name__ =='dict':      ###一个字典类型对象的__name__属性的值为dict;一个list类型对象的__name__属性的值为list;该处用于判断一个secondDict中key所代表的value值是一个子字典还是字符串,如果是字符串则是叶子节点numLeafs += getNumleafs(secondDict[key])else:numLeafs += 1return numLeafs##########获得决策树的层数########def getTreeDepth(myTree):              ###获得树的层数maxDepth = 0thisDepth = 0firstStr = myTree.keys()[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':      ###一个字典类型对象的__name__属性的值为dict;一个list类型对象的__name__属性的值为list;该处用于判断当前节点是否是叶节点thisDepth = 1+ getNumleafs(secondDict[key])else:thisDepth += 1if thisDepth > maxDepth:     ###使用冠军法获得最大层数maxDepth = thisDepthreturn maxDepthdef retrieveTree(i):              ###定义树的结构listOfTree = [ {'no surfacing': {0:'no', 1:{ 'flippers': {0:'no', 1:"yes"} } } },{'no surfacing': {0:'no', 1:{ 'flippers': {0: {'head': {0:'no', 1:'yes'}}, 1:'no'}}}}]return listOfTree[i]def plotMidText(cntrPt, parentPt, txtString):  #parentPt:父母节点的坐标; cntrPt:子节点的坐标; txtString:在父节点和子节点的连线上填充的内容xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  #xMid:要填充内容在x轴上的坐标yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  #yMid:要填充内容在y轴上的坐标createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)def plotTree(myTree, parentPt, nodeTxt):  ###myTree:要画出图形的树; parentPt:根结点的坐标   numLeafs = getNumleafs(myTree)  #获得树的叶节点数depth = getTreeDepth(myTree)  #获得树的层数firstStr = myTree.keys()[0]     #获得树的根节点cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)print cntrPtplotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodesplotTree(secondDict[key],cntrPt,str(key))        #recursionelse:   #it's a leaf node print the leaf nodeplotTree.xOff = plotTree.xOff + 1.0/plotTree.totalWplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#if you do get a dictonary you know it's a tree, and the first element will be another dictdef createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])          ###class dict in module __builtin__ ; dict是一个字典类型的构造函数createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropsesplotTree.totalW = float(getNumleafs(inTree))         ###定义一个变量totalW,该变量属于plotTree,每次访问该变量时需要写成plotTree.totalWplotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;plotTree(inTree, (0.5,1.0), 'tree')plt.show()
if __name__ == "__main__":import treePlotterfr = open('lenses.txt')lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']lenses = [inst.strip().split('\t') for inst in fr.readlines()]lensesTree = createTree(lenses, lensesLabels)createPlot(lensesTree)

                                            
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  python 机器学习