您的位置:首页 > 其它

机器学习之决策树生成和裁剪

2015-11-10 21:56 330 查看
决策树学习比较典型的有三种算法:ID3 C4.5 CART。

决策树是一种分类预测算法,通过训练样本建立的决策树,能够对未来样本进行分类。

决策树算法包括:建立决策树和裁剪决策树。裁剪决策树是为了减少过拟合带来的错误率。建立决策树的过程,是一种递归分级参考属性的过程,这个过程中会使用参考属性对目标属性的依赖关系。如下面例子,参考属性包括:有房、婚姻、收入。 目标属性:拖欠贷款。



ID3和C4.5

ID3和C4.5算法基本流程一致,区别是参考属性分级时,选择的标准不一样。

算法引入了信息论的信息熵,用于定量目标属性的确定性。样本集D中包含m类目标属性样本,每类样本的概率记为p(Ci),则目标属性信息熵定义:

H(D)=−∑i=1mpi(Ci)log2pi(Ci)

当使用参考属性进行分级时,会得到多个样本子集记为D0,D1,..Dk(∑iDi=D),子集Dj中关于目标属性的信息熵记为H(Dj),样本数为|Dj|,则分级后目标信息熵为:

H′(D)=∑i=0k|Di||D|H(Di)

举例来说,上面的训练样本集中拖欠贷款的信息熵为:−310log310−710log710

若使用有房参考属性进行分级,得到2个子集,有房子集样本数为3,没有房子集样本数为7。则分级后关于目标属性的信息熵:

H′(D)=310H(D有房)+710H(D没房)

ID3算法中,定义G(D)=H(D)−H′(D)为信息熵增益,表示经过分级后,对目标属性的判断把握性。使用不同的参考属性进行分级,得到的信息增益不一样。

使用不同的分级属性,得到的信息熵增益不同,ID3判定准则是,选择最大的信息增益对应的分级属性进行分级。

不过ID3的这个判定准则有一个缺陷,它总是倾向于属性值种类多的属性。例如上面的样本集,年收入用数字表达式,总类会有很多(数字是连续的)。因此,分级属性总是会倾向于这个参考属性。

为了解决这个问题,C4.5提出了使用信息增益率作为判定准则,增益率定义为:

GainRatio(D)=G(D)SplitInfo(D)

SplitInfo(D)=−∑j|Dj||D|log|Dj||D|

可以看出,分级后的子集数越多,SplitInfo越大,导致增益率越小。

若使用有房参考属性进行分级

SplitInfo(D)=−(03log03+33log33)−(47log47+37log37)

from math import log
#计算信息熵
def CalShannonEntropy(dataSet): #格式:参考属性1,参考属性2...,目标属性
sampleNum=len(dataSet)
samplecount={}
for data in dataSet:
currentfeature=data[-1]
if currentfeature not in samplecount.keys(): samplecount[currentfeature]=0
samplecount[currentfeature]+=1
entropy=0.0
for key in samplecount.keys():
p=float(samplecount.get(key))/sampleNum;
entropy-=p*log(p,2)
return entropy;
#选择最佳分级特征属性
def ChooseBestFeature(dataSet):
baseEntropy=CalShannonEntropy(dataSet) #第一次分级之前的信息熵
samplenum=len(dataSet)
labelnum=len(dataSet[0])-1

entropymax=0.0
bestLabelIndex=-1

for labelIndex in range(labelnum): #遍历所有的参考特征属性
values=set([example[labelIndex] for example in dataSet])
entsum=0.0
splitinfo=0.0
for value in values:
subdataSet=SplitDataSet(dataSet,labelIndex,value)   #分级后得到的子集D1,D2,...
p=float(len(subdataSet))/samplenum
entsum+=p*CalShannonEntropy(subdataSet) #计算分级后的信息熵
splitinfo-=p*log(p,2)                    #计算分级信息SplitInfo,ID3不用计算
infoGainRatio=(baseEntropy-entsum)/splitinfo #计算信息增益率,ID3计算增益就可以了

if infoGainRatio>entropymax:   #判断最大的增益率或增益
entropymax=infoGainRatio
bestLabelIndex=labelIndex
return bestLabelIndex  #返回最大增益或增益率的参考特征属性


CART

参考

除了ID3和C4.5,还有一种算法CART(classification and regression tree)。这是一种可以处理离散特征值和连续特征值的决策树,处理离散特征值使用分类决策树,处理连续特征值使用回归决策树。

CART的分级判定准则常用的是gini指数。gini指数和信息熵类似,gini指数越低,对目标属性判定越有把握。gini定义如下:

gini(D)=1−∑i=1kp2i

经过分级后,得到子样本集D0,...Dk,gini指数定义为:

gini′(D)=∑jk|Dj||D|gini(Dj)

若使用有房参考属性进行分级后,gini(D)=310(1−(03)2−(33)2)+710(1−(47)2−(37)2)

def CalGini(subdataSet):
#使用gini标准选择最佳分级特征属性
def ChooseBestFeature(dataSet):
samplenum=len(dataSet)
labelnum=len(dataSet[0])-1

ginimax=0.0
bestLabelIndex=-1

for labelIndex in range(labelnum): #遍历所有的参考特征属性
values=set([example[labelIndex] for example in dataSet])
ginisum=0.0
splitinfo=0.0
for value in values:
subdataSet=SplitDataSet(dataSet,labelIndex,value)   #分级后得到的子集D1,D2,...
p=float(len(subdataSet))/samplenum    #子集中样本数占比
ginisum+=p*CalShannonEntropy(subdataSet) #计算分级后的Gini
splitinfo-=p*log(p,2)                    #计算分级信息SplitInfo,ID3不用计算
infoGainRatio=(baseEntropy-entsum)/splitinfo #计算信息增益率,ID3计算增益就可以了

if infoGainRatio>entropymax:   #判断最大的增益率或增益
entropymax=infoGainRatio
bestLabelIndex=labelIndex
return bestLabelIndex  #返回最大增益或增益率的参考特征属性


决策树裁剪

为了防止过拟合,需要对决策树进行裁剪。裁剪分为事前裁剪和事后裁剪。事前裁剪发生在建立决策树时,通过判定规则(例如节点总数>),来决定是否进行新的分级。 事后裁剪发生在建立决策树后,通过判定规则进行树的修剪。常用的事后裁剪方法一般为CCP:代价复杂性修剪法

CCP对决策树中的每个非叶子节点定义了一个表面误差增益值α

α=R(node)−R(leaf)|Nleaf|−1

R(t)=minlr(|t|)|t|l+|t|r∗|t|l+|t|r|T|

例如:



设样本集样本总数为100

R(T4)=7100,R(T8)=2100,R(T9)=0100,R(T7)=3100,R(T6)=3100

αT4=R(T4)−(R(T7)+R(T8)+R(T9))3−1=2200

αT6=R(T6)−(R(T8)+R(T9))3−1=1100

实验:

训练样本



二值图标记



建立决策树,完成二值图。



import numpy as np
from skimage import io
from skimage import color
from math import log
import operator
import matplotlib.pyplot as plt

def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: #the the number of unique elements and their occurance
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) #log base 2
return shannonEnt

def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet

def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):        #iterate over all the features
featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
uniqueVals = set(featList)       #get a set of unique values
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
if (infoGain > bestInfoGain):       #compare this to the best gain so far
bestInfoGain = infoGain         #if better than current best, set to best
bestFeature = i
return bestFeature                      #returns an integer

def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]

def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]#stop splitting when all of the classes are equal
if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
#Tree Format
#{FeatureLabel1: {Value1_1:{FeatureLabel2:{Value2_1:{},Value2_2:{},Value2_3:{}}},Value1_2:{},Value1_3:{}}}
#

#验证
def classify(inputTree,featLabels,testVec):
firstStr = inputTree.keys()[0]  #best label
secondDict = inputTree[firstStr]  #dict
featIndex = featLabels.index(firstStr) #best feature index
key = testVec[featIndex]                #testValue
if key not in secondDict:
return -1
valueOfFeat = secondDict[key]           #

if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat

return classLabel

plt.figure(figsize(20,10))
test_img=io.imread('E:\BaiduYunDownload\ML\project\\flower_test.png')
plt.subplot(221)
plt.imshow(test_img)
mask_img=io.imread('E:\BaiduYunDownload\ML\project\\flower_mask.png')
plt.subplot(222)
plt.imshow(mask_img,cmap=plt.cm.gray)

img_h,img_w,dim=test_img.shape
test_data=floor(test_img.reshape(img_w*img_h,dim)/32)
mask_data=mask_img.reshape(img_w*img_h)

n_data=test_data[mask_data==0]
p_data=test_data[mask_data==255]
n_data=np.hstack((n_data,zeros((n_data.shape[0],1))))
p_data=np.hstack((p_data,ones((p_data.shape[0],1))))
dataSet=np.vstack((n_data,p_data))
labels=['R','G','B']

print "Create Tree"
myTree=createTree(dataSet.tolist(),labels)

print "testing"
result=np.arange(img_h*img_w)
test=test_data.reshape((img_h*img_w,3))
index=0
lastresult=0
labels=['R','G','B']
for t in test:
classLabel=classify(myTree,labels,t)
if classLabel==0:
result[index]=0
lastresult=0
elif classLabel==1:
result[index]=255
lastresult=255
else:
result[index]=lastresult
index+=1

result=result.reshape((img_h,img_w))
plt.subplot(223)
plt.imshow(result,cmap=plt.cm.gray)
plt.subplots
print "finished"


若是使用CART算法,则结果变为

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  机器学习