
ID3 Decision Tree Implementation

The ID3 decision tree splits nodes according to information gain.
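For reference, these are the standard definitions the code below implements. For a data set D in which a fraction p_k of the samples belongs to class k, the Shannon entropy is

H(D) = -Σ_k p_k · log2(p_k)

and the information gain of splitting D on feature A is

Gain(D, A) = H(D) - Σ_v (|D_v| / |D|) · H(D_v)

where D_v is the subset of D whose value of feature A is v. At every node, ID3 greedily picks the feature with the largest gain.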

I. Creating the ID3 decision tree

<1> Creating a simple data set:

Create the data set according to the following table (reconstructed here from the figure in the original post):

Animal  Can survive without surfacing?  Has flippers?  Fish?
1       yes                             yes            yes
2       yes                             yes            yes
3       yes                             no             no
4       no                              yes            no
5       no                              yes            no

The table lists 5 marine animals with two features: (1) whether the animal can survive without coming to the surface, and (2) whether it has flippers. We want to divide these animals into two classes: fish and non-fish.

#create data set and labels
def CreateDataSet():
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

<2> Calculating the Shannon entropy of the data set

from math import log

#Function to calculate the Shannon entropy of a dataset
def CalculateShannonEnt(dataset):
    entropy_num = len(dataset)
    label_count = {}  #dictionary that counts the occurrences of each label
    for feature_vector in dataset:  #each feature_vector is one element of dataset (the elements of dataset are themselves lists)
        current_label = feature_vector[-1]  #the class of the current marine animal, i.e. its label
        if current_label not in label_count.keys():  #check whether the current label is already in label_count
            label_count[current_label] = 0  #initialize so the +1 below works uniformly
        label_count[current_label] += 1
    shannonent = 0.0
    for key in label_count:  #compute the Shannon entropy
        prob = float(label_count[key])/entropy_num
        shannonent -= prob*log(prob, 2)
    return shannonent

Test:

dataset, labels = CreateDataSet()
print "Shannon Entropy = ", CalculateShannonEnt(dataset)

Shannon Entropy =  0.970950594455
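As a quick sanity check, the data set has two 'yes' labels and three 'no' labels out of five samples, so

H = -(2/5)·log2(2/5) - (3/5)·log2(3/5) ≈ 0.9710,

which matches the printed value.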


<3> Splitting the data set on a given feature

#Dataset splitting on a given feature
#feature_axis is the column index of the given feature; feature_value is the value to match in that column
#e.g.: dataset = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
#split = SplitDataSet(dataset, 0, 1) gives split = [[1, 'yes'], [1, 'yes'], [0, 'no']]
def SplitDataSet(dataset, feature_axis, feature_value):
    split_set = []
    for featvec in dataset:
        if featvec[feature_axis] == feature_value:
            reduced_feature_vector = featvec[:feature_axis]  #drop the current feature value, keeping the columns before it
            reduced_feature_vector.extend(featvec[feature_axis+1:])  #extend vs. append is explained separately below
            split_set.append(reduced_feature_vector)
    return split_set
  The reason a new list split_set is created is that Python passes function arguments by reference; modifying the original list dataset inside the function would change dataset itself, so a fresh list is needed.
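A minimal illustration of this point (strip_first is a hypothetical helper, not part of the tree code):

def strip_first(rows):
    for row in rows:
        del row[0]  #deletes in place: the caller's inner lists are modified

data = [[1, 'yes'], [0, 'no']]
strip_first(data)
print data  #prints [['yes'], ['no']] -- the original nested lists have changed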

  The difference between extend and append:

a = [1, 2, 3]
b = [4, 5, 6]
a.extend(b)
print a
[1, 2, 3, 4, 5, 6]

a = [1, 2, 3]
a.append(b)
print a
[1, 2, 3, [4, 5, 6]]
  So a.extend(b) produces a single list containing all the elements of a and b, whereas a.append(b) adds a fourth element to a, and that fourth element is itself a list.

Test:

dataset, labels = CreateDataSet()
splitset = SplitDataSet(dataset, 0, 1)
print splitset
Result:
[[1, 'yes'], [1, 'yes'], [0, 'no']]
<4> Choosing the best way to split the data set

# Choosing the best feature to split on
#select the best feature according to information gain
def ChooseBestFeature(dataset):
    feature_number = len(dataset[0])-1  #number of features in the data set
    base_entropy = CalculateShannonEnt(dataset)  #entropy of the whole data set
    best_info_gain = 0.0  #largest information gain seen so far
    best_feature = -1  #best feature
    for i in xrange(feature_number):
        feat_list = [example[i] for example in dataset]  #collect column i into feat_list
        unique_value = set(feat_list)  #sets are duplicate-free, so this removes repeated values
        temp_entropy = 0.0
        for value in unique_value:
            subset = SplitDataSet(dataset, i, value)  #split the data set on this value of feature i
            prob = len(subset)/float(len(dataset))  #fraction of the data that falls into this split
            temp_entropy += prob * CalculateShannonEnt(subset)
        infogain = base_entropy - temp_entropy
        #print "Information gain of feature %d:" % i, infogain  #uncomment to check that the chosen feature maximizes the gain
        if infogain > best_info_gain:
            best_info_gain = infogain
            best_feature = i
    return best_feature


Test:

dataset, labels = CreateDataSet()
print ChooseBestFeature(dataset)
Result:

0
  This result says that feature 0 is the best feature for splitting the data set. Is that choice right? Uncomment the line '#print "Information gain of feature %d:" % i, infogain' in the code above and you get:
Information gain of feature 0: 0.419973094022
Information gain of feature 1: 0.170950594455
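The gain for feature 0 is easy to verify by hand: splitting on 'no surfacing' produces a 3-sample subset with 2 'yes' and 1 'no' (entropy -(2/3)·log2(2/3) - (1/3)·log2(1/3) ≈ 0.9183) and a 2-sample subset that is all 'no' (entropy 0), so

Gain = 0.9710 - (3/5 · 0.9183 + 2/5 · 0) ≈ 0.4200,

in agreement with the printed value.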
<5> Selecting the majority label

import operator

#vote over the labels and find the most common one
def Majoritycount(classlist):
    classcount = {}
    for vote in classlist:
        if vote not in classcount.keys():
            classcount[vote] = 0
        classcount[vote] += 1
    #sort the labels by count, from largest to smallest
    sortedclasscount = sorted(classcount.iteritems(), key = operator.itemgetter(1), reverse = True)
    #return the most common label
    return sortedclasscount[0][0]
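A quick illustrative call:

print Majoritycount(['yes', 'no', 'yes', 'no', 'no'])  #prints: no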
<6> The tree-building function

#Tree-building code
def CreateTree(dataset, labels):
    classlist = [example[-1] for example in dataset]  #collect all the class labels
    if classlist.count(classlist[0]) == len(dataset):  #all samples belong to the same class: stop the recursion
        return classlist[0]
    if len(dataset[0]) == 1:  #no features left to split on: return the majority label
        return Majoritycount(classlist)
    bestfeat = ChooseBestFeature(dataset)
    best_label = labels[bestfeat]

    my_tree = {best_label:{}}
    del(labels[bestfeat])
    feat_values = [example[bestfeat] for example in dataset]  #collect the values of column bestfeat
    unique_value = set(feat_values)
    for value in unique_value:
        sublabels = labels[:]  #sublabels holds the labels with the best feature removed
        my_tree[best_label][value] = CreateTree(SplitDataSet(dataset, bestfeat, value), sublabels)
    return my_tree
Test:

dataset, labels = CreateDataSet()
print CreateTree(dataset, labels)
Result:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
Rendered as a diagram, this dictionary corresponds to the tree drawn by the plotting code in Part II below.
II. Drawing the tree with Python

# -*- coding: utf-8 -*-
"""
Created on Tue Nov 01 11:00:37 2016

@author: G
"""
import matplotlib.pyplot as plt

decision_node = dict(boxstyle = 'sawtooth', fc = "0.8")
leaf_node = dict(boxstyle = 'round4', fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

#count the leaf nodes; needed when drawing the tree to compute node and arrow positions
def GetLeafsNumber(mytree):
    leafs_number = 0
    first_feature = mytree.keys()[0]
    first_value = mytree[first_feature]
    for key in first_value.keys():
        if type(first_value[key]).__name__ == 'dict':
            leafs_number += GetLeafsNumber(first_value[key])
        else:
            leafs_number += 1
    return leafs_number

#compute the depth of the tree
def GetTreeDepth(mytree):
    max_depth = 0
    first_feature = mytree.keys()[0]
    first_value = mytree[first_feature]
    for key in first_value.keys():
        if type(first_value[key]).__name__ == 'dict':
            this_depth = GetTreeDepth(first_value[key]) + 1
        else:
            this_depth = 1
        if this_depth > max_depth:
            max_depth = this_depth
    return max_depth

#draw a node labelled with its attribute name
def PlotNode(node_txt, center_pt, parent_pt, node_type):
    CreatePlot.ax1.annotate(node_txt, xy = parent_pt, xycoords = "axes fraction",
                            xytext = center_pt, textcoords = "axes fraction",
                            va = "center", ha = "center", bbox = node_type, arrowprops = arrow_args)

#write the attribute value halfway along the arrow, i.e. on the branch
def PlotMidText(center_pt, parent_pt, txt_string):
    xmid = (parent_pt[0] - center_pt[0])/2.0 + center_pt[0]*1.03  #the factor 1.03 keeps the value from sitting directly on the branch
    ymid = (parent_pt[1] - center_pt[1])/2.0 + center_pt[1]
    CreatePlot.ax1.text(xmid, ymid, txt_string, va = "center", ha = "center", rotation = 30)

#draw the tree
def PlotTree(mytree, parent_pt, nodetxt):
    leafs_number = GetLeafsNumber(mytree)
    first_feature = mytree.keys()[0]
    #compute center_pt, the midpoint position of this node
    center_pt = (PlotTree.xOff + (1.0 + float(leafs_number))/2.0/PlotTree.totalW, PlotTree.yOff)
    PlotMidText(center_pt, parent_pt, nodetxt)
    PlotNode(first_feature, center_pt, parent_pt, decision_node)

    #the value stored under the first key of the dictionary is this node's subtree
    first_value = mytree[first_feature]
    PlotTree.yOff = PlotTree.yOff - 1.0/PlotTree.totalD
    for key in first_value.keys():
        if type(first_value[key]).__name__ == 'dict':
            PlotTree(first_value[key], center_pt, str(key))
        else:
            PlotTree.xOff = PlotTree.xOff + 1.0/PlotTree.totalW
            PlotNode(first_value[key], (PlotTree.xOff, PlotTree.yOff), center_pt, leaf_node)
            PlotMidText((PlotTree.xOff, PlotTree.yOff), center_pt, str(key))
    PlotTree.yOff = PlotTree.yOff + 1.0/PlotTree.totalD

def CreatePlot(input_tree):
    fig = plt.figure(1, facecolor = "white")
    fig.clf()
    axprops = dict(xticks = [], yticks = [])
    CreatePlot.ax1 = plt.subplot(111, frameon = False, **axprops)
    PlotTree.totalW = float(GetLeafsNumber(input_tree))
    PlotTree.totalD = float(GetTreeDepth(input_tree))
    PlotTree.xOff = -0.5/PlotTree.totalW
    PlotTree.yOff = 1.0
    PlotTree(input_tree, (0.5, 1.0), '')
    plt.show()

'''
def CreatePlot():
    fig = plt.figure(1, facecolor = "white")
    fig.clf()
    CreatePlot.ax1 = plt.subplot(111, frameon = False)
    PlotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decision_node)
    PlotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leaf_node)
    plt.show()
'''

def RetrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
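A quick way to exercise this module is to feed it one of the canned trees from RetrieveTree; for the first tree, GetLeafsNumber returns 3 and GetTreeDepth returns 2:

mytree = RetrieveTree(0)
print GetLeafsNumber(mytree)  #3
print GetTreeDepth(mytree)    #2
CreatePlot(mytree)            #draws the tree in a matplotlib window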


III. Complete code
<1> DecisionTrees.py
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 31 19:33:15 2016

@author: G
"""

from math import log
import operator
import DecisionTreePlotter as DTP

#create data set and labels
def CreateDataSet():
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels

#Function to calculate the Shannon entropy of a dataset
def CalculateShannonEnt(dataset):
    entropy_num = len(dataset)
    label_count = {}  #dictionary that counts the occurrences of each label
    for feature_vector in dataset:
        current_label = feature_vector[-1]  #the class of the current sample, i.e. its label
        if current_label not in label_count.keys():
            label_count[current_label] = 0
        label_count[current_label] += 1
    shannonent = 0.0
    for key in label_count:  #compute the Shannon entropy
        prob = float(label_count[key])/entropy_num
        shannonent -= prob*log(prob, 2)
    return shannonent

#Dataset splitting on a given feature
def SplitDataSet(dataset, feature_axis, feature_value):
    split_set = []
    for featvec in dataset:
        if featvec[feature_axis] == feature_value:
            reduced_feature_vector = featvec[:feature_axis]  #drop the current feature column
            reduced_feature_vector.extend(featvec[feature_axis+1:])
            split_set.append(reduced_feature_vector)
    return split_set

# Choosing the best feature to split on
#select the best feature according to information gain
def ChooseBestFeature(dataset):
    feature_number = len(dataset[0])-1  #number of features in the data set
    base_entropy = CalculateShannonEnt(dataset)  #entropy of the whole data set
    best_info_gain = 0.0
    best_feature = -1
    for i in xrange(feature_number):
        feat_list = [example[i] for example in dataset]
        unique_value = set(feat_list)  #remove duplicate values
        temp_entropy = 0.0
        for value in unique_value:
            subset = SplitDataSet(dataset, i, value)
            prob = len(subset)/float(len(dataset))
            temp_entropy += prob * CalculateShannonEnt(subset)
        infogain = base_entropy - temp_entropy
        #print "Information gain of feature %d:" % i, infogain
        if infogain > best_info_gain:
            best_info_gain = infogain
            best_feature = i
    return best_feature  #returns the index of the best feature

#vote over the labels and find the most common label in classlist
def Majoritycount(classlist):
    classcount = {}
    for vote in classlist:
        if vote not in classcount.keys():
            classcount[vote] = 0
        classcount[vote] += 1
    #sort the labels by count, from largest to smallest
    sortedclasscount = sorted(classcount.iteritems(), key = operator.itemgetter(1), reverse = True)
    #return the most common label
    return sortedclasscount[0][0]

#Tree-building code
def CreateTree(dataset, labels):
    classlist = [example[-1] for example in dataset]  #collect all the class labels
    if classlist.count(classlist[0]) == len(dataset):  #all samples in the same class: stop the recursion
        return classlist[0]
    if len(dataset[0]) == 1:  #no features left to split on: return the majority label
        return Majoritycount(classlist)
    bestfeat = ChooseBestFeature(dataset)
    best_label = labels[bestfeat]
    my_tree = {best_label:{}}
    del(labels[bestfeat])
    feat_values = [example[bestfeat] for example in dataset]  #collect the values of column bestfeat
    unique_value = set(feat_values)
    for value in unique_value:
        sublabels = labels[:]  #labels with the best feature removed
        my_tree[best_label][value] = CreateTree(SplitDataSet(dataset, bestfeat, value), sublabels)
    return my_tree

#classify a given test vector with the tree
def Classify(input_tree, feat_labels, test_vector):
    first_feature = input_tree.keys()[0]
    first_value = input_tree[first_feature]
    feature_index = feat_labels.index(first_feature)
    for key in first_value.keys():
        if test_vector[feature_index] == key:
            if type(first_value[key]).__name__ == 'dict':  #a dict here means there is a subtree
                class_label = Classify(first_value[key], feat_labels, test_vector)
            else:
                class_label = first_value[key]
    return class_label

#store the tree on disk so it does not have to be rebuilt for every classification, saving time
def StoreTree(input_tree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(input_tree, fw)
    fw.close()

def ReadTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

#test case
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lenses_labels = ['age', 'prescript', 'astigmatic1', 'tearRate1']
lense_tree = CreateTree(lenses, lenses_labels)

print lense_tree
DTP.CreatePlot(lense_tree)

'''
dataset, labels = CreateDataSet()
print labels
mytree = DTP.RetrieveTree(0)

print Classify(mytree, labels, [1, 0])
'''
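A minimal round-trip check of the pickle helpers above (the filename 'tree.pkl' is just an illustration):

dataset, labels = CreateDataSet()
mytree = CreateTree(dataset, labels[:])  #pass a copy: CreateTree deletes from the labels list
StoreTree(mytree, 'tree.pkl')
print ReadTree('tree.pkl')  #{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}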

<2> DecisionTreePlotter.py

(DecisionTreePlotter.py is the plotting module listed in full in Part II above; the code is identical.)

<3> Test result:
lense_tree = {'tearRate1': {'reduced': 'no lenses', 'normal': {'astigmatic1': {'yes':
{'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}},
'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript':
{'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
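Reading a prediction off this tree with Classify (note that CreateTree deletes entries from the labels list it is handed, so pass a fresh copy when classifying):

lenses_labels = ['age', 'prescript', 'astigmatic1', 'tearRate1']
#a young, myopic, non-astigmatic patient with a normal tear rate
print Classify(lense_tree, lenses_labels, ['young', 'myope', 'no', 'normal'])  #prints: soft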



Experiment data: http://download.csdn.net/detail/hearthougan/9664933

The ID3 decision tree has a number of shortcomings; for example, it is prone to overfitting, which calls for pruning strategies. For details, see Zhou Zhihua's book Machine Learning:

http://download.csdn.net/detail/hearthougan/9652865