
Decision Trees (决策树)

2018-02-27 20:05
The core idea of a decision tree: compute the entropy of the dataset, split it on the feature that yields the largest information gain, then recursively split each resulting subset the same way, until every instance in a subset belongs to the same class. The result is a "decision tree". To use the tree, take a new instance's feature values and walk down from the root; the leaf you reach is the instance's predicted class (a classification sketch is given after the listing below). This paragraph will probably make more sense once you have read the implementation.

For more background on information, entropy, and information gain, see the companion post "信息&熵&信息增益" (Information & Entropy & Information Gain).

Advantages: low computational cost, results that are easy to interpret, and tolerance of irrelevant features.

Disadvantages: prone to overfitting.

Applicable data types: numeric and nominal.

The ID3 algorithm is used here to split the dataset.
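For reference, ID3 scores each candidate feature A by its information gain: the drop in Shannon entropy of the class labels after partitioning the dataset D on A, where p_k is the fraction of records in class k:

H(D) = -\sum_k p_k \log_2 p_k
g(D, A) = H(D) - H(D \mid A)

The code below computes exactly these quantities: calShannonEnt gives H(D), and chooseBestFeatureToSplit accumulates H(D|A) for each feature and keeps the one with the largest gain.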

from math import log
import operator

def calShannonEnt(dataSet):
    """
    Compute the Shannon entropy of a dataset.
    :param dataSet: list of records, each ending with a class label
    :return: the Shannon entropy of the class labels
    """
    numEntries = len(dataSet)
    labelCounts = {}
    # Count how many records fall into each class.
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # H = -sum(p * log2(p)) over all classes.
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def createDataset():
    """Build a toy dataset: two binary features and a yes/no class label."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
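As a quick sanity check (my own usage sketch, not part of the original listing): the toy set has 2 'yes' and 3 'no' labels, so its entropy is -(2/5)·log2(2/5) - (3/5)·log2(3/5) ≈ 0.971:

myData, labels = createDataset()
print(calShannonEnt(myData))  # ≈ 0.9710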

def splitDataSet(dataSet, axis, value):
    """
    Split the dataset on a given feature.
    :param dataSet: dataset to split
    :param axis: index of the feature to split on
    :param value: feature value to keep
    :return: records whose feature equals value, with that feature column removed
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Keep the record, minus the feature column used for the split.
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
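For example (again a usage sketch of my own): splitting the toy set on feature 0 with value 1 keeps the three records whose first feature is 1 and strips that column:

myData, _ = createDataset()
print(splitDataSet(myData, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]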

def chooseBestFeatureToSplit(dataSet):
    """
    Choose the best feature to split on.
    :param dataSet: dataset whose last column is the class label
    :return: index of the feature with the largest information gain
    """
    numFeatures = len(dataSet[0]) - 1
    # Entropy of the full dataset
    baseEntropy = calShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        # Conditional entropy after splitting on feature i
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calShannonEnt(subDataSet)
        # Information gain = entropy - conditional entropy
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
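On the toy dataset this works out as follows (my own arithmetic, easy to verify by hand): splitting on feature 0 ('no surfacing') gives conditional entropy (3/5)·0.918 + (2/5)·0 ≈ 0.551 and gain 0.971 - 0.551 ≈ 0.420, while splitting on feature 1 ('flippers') gives (4/5)·1.0 + (1/5)·0 = 0.8 and gain ≈ 0.171. So:

myData, _ = createDataset()
print(chooseBestFeatureToSplit(myData))  # 0, i.e. 'no surfacing'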

def majorityCnt(classList):
    """Return the class label that occurs most often in classList."""
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount[0][0]
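As an aside, the standard library offers a one-line equivalent; collections.Counter is my substitution here, not the book's code:

from collections import Counter

def majorityCnt2(classList):
    # Equivalent to majorityCnt: the most common class label wins.
    return Counter(classList).most_common(1)[0][0]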

def createTree(dataSet, labels):
    """
    Build the decision tree recursively.
    Note: this mutates labels; pass a copy if the caller still needs it.
    :param dataSet: dataset whose last column is the class label
    :param labels: feature names corresponding to the feature columns
    :return: the tree as nested dicts, or a class label at a leaf
    """
    classList = [example[-1] for example in dataSet]
    # Stop if every record has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop if no features are left; fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # Copy labels so sibling branches don't interfere with each other.
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def main():
    myData, labels = createDataset()
    myTree = createTree(myData, labels)
    print(myTree)

if __name__ == '__main__':
    main()
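Running main() prints the tree as nested dicts: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}. The listing builds the tree but stops short of the "walk down from the root" step described at the top. Here is a minimal classification sketch along those lines (classify is my addition, not part of the original listing; note the labels copy, since createTree deletes from the list it is given):

def classify(inputTree, featLabels, testVec):
    """Walk the tree from the root; the leaf reached is the predicted class."""
    firstStr = list(inputTree.keys())[0]        # feature name at this node
    secondDict = inputTree[firstStr]            # maps feature value -> subtree or leaf
    featIndex = featLabels.index(firstStr)
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):           # internal node: keep walking
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat                          # leaf: a class label

myData, labels = createDataset()
myTree = createTree(myData, labels[:])
print(classify(myTree, labels, [1, 1]))  # 'yes'
print(classify(myTree, labels, [1, 0]))  # 'no'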


References:

Harrington, P. Machine Learning in Action [M]. Translated by Li Rui. Posts & Telecom Press, 2013.