ID3 Decision Tree Implementation
2016-11-03 14:12
An ID3 decision tree splits its nodes according to information gain.
I. Building the ID3 decision tree
<1> Create a simple dataset
The dataset is based on the following table: five marine animals, each described by two features: (1) whether it can survive without coming to the surface, and (2) whether it has flippers. The animals fall into two classes: fish and non-fish.
# Create the data set and labels
def CreateDataSet():
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels
<2> Compute the Shannon entropy of the dataset
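For reference, the quantity computed below is the standard Shannon entropy: for a dataset D whose class labels occur with probabilities p_k,

H(D) = -\sum_{k} p_k \log_2 p_k

For this dataset p(yes) = 2/5 and p(no) = 3/5, so H(D) ≈ 0.9710, which matches the test output below.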
# Function to calculate the Shannon entropy of a dataset
from math import log

def CalculateShannonEnt(dataset):
    entropy_num = len(dataset)
    label_count = {}  # maps each class label to the number of samples with that label
    for feature_vector in dataset:  # each element of dataset is itself a list (one sample)
        current_label = feature_vector[-1]  # the last column is the class label
        if current_label not in label_count.keys():  # label seen for the first time
            label_count[current_label] = 0  # initialise so the +1 below works uniformly
        label_count[current_label] += 1
    shannonent = 0.0
    for key in label_count:  # accumulate the Shannon entropy
        prob = float(label_count[key])/entropy_num
        shannonent -= prob*log(prob, 2)
    return shannonent
Test:
dataset, labels = CreateDataSet()
print "Shannon Entropy = ", CalculateShannonEnt(dataset)

Shannon Entropy = 0.970950594455
<3> Split the dataset on a given feature
# Dataset splitting on a given feature
# feature_axis is the column index of the chosen feature; feature_value is the
# value of that feature to keep. For example, with
# dataset = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']],
# SplitDataSet(dataset, 0, 1) returns [[1, 'yes'], [1, 'yes'], [0, 'no']].
def SplitDataSet(dataset, feature_axis, feature_value):
    split_set = []
    for featvec in dataset:
        if featvec[feature_axis] == feature_value:
            reduced_feature_vector = featvec[:feature_axis]  # keep the columns before the chosen feature
            reduced_feature_vector.extend(featvec[feature_axis+1:])  # append the columns after it (extend vs. append is explained below)
            split_set.append(reduced_feature_vector)
    return split_set

A new list split_set is created because Python passes arguments by reference: modifying dataset inside the function would change the caller's list, so the matching rows are copied into a fresh list instead.
The difference between extend and append:
a = [1, 2, 3]
b = [4, 5, 6]
a.extend(b)
print a    # [1, 2, 3, 4, 5, 6]

a = [1, 2, 3]
a.append(b)
print a    # [1, 2, 3, [4, 5, 6]]

So a.extend(b) yields one list containing all the elements of a and b, whereas a.append(b) adds a single fourth element which is itself the list b.
Test:
dataset, labels = CreateDataSet()
splitset = SplitDataSet(dataset, 0, 1)
print splitset

Result:
[[1, 'yes'], [1, 'yes'], [0, 'no']]

<4> Choose the best way to split the dataset
# Choosing the best feature to split on
# (the feature with the largest information gain)
def ChooseBestFeature(dataset):
    feature_number = len(dataset[0])-1  # number of features (the last column is the label)
    base_entropy = CalculateShannonEnt(dataset)  # entropy of the whole dataset
    best_info_gain = 0.0  # largest information gain found so far
    best_feature = -1     # index of the best feature
    for i in xrange(feature_number):
        feat_list = [example[i] for example in dataset]  # the i-th column of feature values
        unique_value = set(feat_list)  # a set removes duplicate values
        temp_entropy = 0.0
        for value in unique_value:
            subset = SplitDataSet(dataset, i, value)  # split the dataset on the i-th feature
            prob = len(subset)/float(len(dataset))    # fraction of samples in this subset
            temp_entropy += prob * CalculateShannonEnt(subset)
        infogain = base_entropy - temp_entropy
        #print "information gain of feature %d:" % i, infogain  # uncomment to check that the chosen feature has the largest gain
        if infogain > best_info_gain:
            best_info_gain = infogain
            best_feature = i
    return best_feature
Test:
dataset, labels = CreateDataSet()
print ChooseBestFeature(dataset)

Result:
0

This says that feature 0 is the best feature to split the dataset on. Is that choice correct? Uncomment the print statement in the code above and you get:
information gain of feature 0: 0.419973094022
information gain of feature 1: 0.170950594455
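As a sanity check, the gain for feature 0 can be verified by hand. Splitting on feature 0 yields the label subsets {yes, yes, no} (3 samples) and {no, no} (2 samples), so

\operatorname{Gain}(D, 0) = H(D) - \tfrac{3}{5} H(\tfrac{2}{3}, \tfrac{1}{3}) - \tfrac{2}{5} \cdot 0 \approx 0.9710 - 0.6 \times 0.9183 \approx 0.4200,

which agrees with the printed value.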
<5> Select the majority label

# Vote over the labels and return the label that occurs most often in classlist
import operator

def Majoritycount(classlist):
    classcount = {}
    for vote in classlist:
        if vote not in classcount.keys():
            classcount[vote] = 0
        classcount[vote] += 1
    # sort the labels by count, in descending order
    sortedclasscount = sorted(classcount.iteritems(), key = operator.itemgetter(1), reverse = True)
    # return the label with the largest count
    return sortedclasscount[0][0]
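A quick check of the voting helper (a minimal sketch; the example list is made up):

print Majoritycount(['yes', 'no', 'no'])   # prints 'no': 'no' occurs twice, 'yes' once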
<6> The tree-building function

# Tree-building code
def CreateTree(dataset, labels):
    classlist = [example[-1] for example in dataset]  # all class labels
    if classlist.count(classlist[0]) == len(classlist):  # all samples share one class: stop recursing
        return classlist[0]
    if len(dataset[0]) == 1:  # no features left (only the label column): return the majority label
        return Majoritycount(classlist)
    bestfeat = ChooseBestFeature(dataset)
    best_label = labels[bestfeat]
    my_tree = {best_label: {}}
    del(labels[bestfeat])
    feat_values = [example[bestfeat] for example in dataset]  # the column of values of the best feature
    unique_value = set(feat_values)
    for value in unique_value:
        sublabels = labels[:]  # labels with the best feature removed
        my_tree[best_label][value] = CreateTree(SplitDataSet(dataset, bestfeat, value), sublabels)
    return my_tree

Test:
dataset, labels = CreateDataSet()
print CreateTree(dataset, labels)

Result:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

Rendered as a diagram, this is the tree drawn in the next section.
II. Plotting the tree with Python
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 01 11:00:37 2016
@author: G
"""
import matplotlib.pyplot as plt
decision_node = dict(boxstyle = 'sawtooth', fc = "0.8")
leaf_node = dict(boxstyle = 'round4', fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
# Count the leaf nodes; used when plotting to compute node and arrow positions
def GetLeafsNumber(mytree):
    leafs_number = 0
    first_feature = mytree.keys()[0]
    first_value = mytree[first_feature]
for key in first_value.keys():
if type(first_value[key]).__name__ == 'dict':
leafs_number += GetLeafsNumber(first_value[key])
else:
leafs_number += 1
return leafs_number
# Compute the depth of the tree
def GetTreeDepth(mytree):
    max_depth = 0
    first_feature = mytree.keys()[0]
    first_value = mytree[first_feature]
for key in first_value.keys():
if type(first_value[key]).__name__ == 'dict':
            this_depth = GetTreeDepth(first_value[key]) + 1
else:
this_depth = 1
if this_depth > max_depth:
max_depth = this_depth
return max_depth
# Draw a node (an attribute name)
def PlotNode(node_txt, center_pt, parent_pt, node_type):
CreatePlot.ax1.annotate(node_txt, xy = parent_pt, xycoords = "axes fraction",
xytext = center_pt, textcoords = "axes fraction",
va = "center", ha = "center", bbox = node_type, arrowprops = arrow_args)
# Add the edge label (the attribute value) midway along the arrow, i.e. on the branch
def PlotMidText(center_pt, parent_pt, txt_string):
    xmid = (parent_pt[0] - center_pt[0])/2.0 + center_pt[0]*1.03  # the factor 1.03 keeps the label from sitting directly on the branch
ymid = (parent_pt[1] - center_pt[1])/2.0 + center_pt[1]
CreatePlot.ax1.text(xmid, ymid, txt_string, va = "center", ha = "center", rotation = 30)
# Draw the tree
def PlotTree(mytree, parent_pt, nodetxt):
leafs_number = GetLeafsNumber(mytree)
first_feature = mytree.keys()[0]
    # compute center_pt, the centre position of this node
center_pt = (PlotTree.xOff + (1.0 + float(leafs_number))/2.0/PlotTree.totalW, PlotTree.yOff)
PlotMidText(center_pt, parent_pt, nodetxt)
PlotNode(first_feature, center_pt, parent_pt, decision_node)
    # the value under the first key stores this node's subtrees
first_value = mytree[first_feature]
PlotTree.yOff = PlotTree.yOff - 1.0/PlotTree.totalD
for key in first_value.keys():
if type(first_value[key]).__name__ == 'dict':
PlotTree(first_value[key], center_pt, str(key))
else:
PlotTree.xOff = PlotTree.xOff + 1.0/PlotTree.totalW
PlotNode(first_value[key], (PlotTree.xOff, PlotTree.yOff), center_pt, leaf_node)
PlotMidText((PlotTree.xOff, PlotTree.yOff), center_pt, str(key))
PlotTree.yOff = PlotTree.yOff + 1.0/PlotTree.totalD
def CreatePlot(input_tree):
fig = plt.figure(1, facecolor = "white")
fig.clf()
axprops = dict(xticks = [], yticks = [])
CreatePlot.ax1 = plt.subplot(111, frameon = False, **axprops)
PlotTree.totalW = float(GetLeafsNumber(input_tree))
PlotTree.totalD = float(GetTreeDepth(input_tree))
PlotTree.xOff = -0.5/PlotTree.totalW
PlotTree.yOff = 1.0
PlotTree(input_tree, (0.5, 1.0), '')
plt.show()
'''
def CreatePlot():
fig = plt.figure(1, facecolor = "white")
fig.clf()
CreatePlot.ax1 = plt.subplot(111, frameon = False)
PlotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decision_node)
PlotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leaf_node)
plt.show()
'''
def RetrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
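These helpers can be exercised on the small hard-coded trees returned by RetrieveTree (a minimal sketch; the expected values follow from the first tree's structure):

mytree = RetrieveTree(0)
print GetLeafsNumber(mytree)   # 3 -- the leaves 'no', 'no' and 'yes'
print GetTreeDepth(mytree)     # 2
CreatePlot(mytree)             # draws the tree in a matplotlib window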
III. Complete code
<1> DecisionTrees.py
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 31 19:33:15 2016
@author: G
"""
from math import log
import operator
import DecisionTreePlotter as DTP
# Create the data set and labels
def CreateDataSet():
dataset = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataset, labels
# Function to calculate the Shannon entropy of a dataset
def CalculateShannonEnt(dataset):
    entropy_num = len(dataset)
    label_count = {}  # maps each class label to the number of samples with that label
    for feature_vector in dataset:  # each element of dataset is itself a list (one sample)
        current_label = feature_vector[-1]  # the last column is the class label
        if current_label not in label_count.keys():  # label seen for the first time
            label_count[current_label] = 0  # initialise so the +1 below works uniformly
        label_count[current_label] += 1
    shannonent = 0.0
    for key in label_count:  # accumulate the Shannon entropy
        prob = float(label_count[key])/entropy_num
        shannonent -= prob*log(prob, 2)
    return shannonent
# Dataset splitting on a given feature
# feature_axis is the column index of the chosen feature; feature_value is the value to keep.
# Example: dataset = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
# SplitDataSet(dataset, 0, 1) == [[1, 'yes'], [1, 'yes'], [0, 'no']]
def SplitDataSet(dataset, feature_axis, feature_value):
    split_set = []
    for featvec in dataset:
        if featvec[feature_axis] == feature_value:
            reduced_feature_vector = featvec[:feature_axis]  # keep the columns before the chosen feature
            reduced_feature_vector.extend(featvec[feature_axis+1:])  # append the columns after it
            split_set.append(reduced_feature_vector)
    return split_set

# Choosing the best feature to split on
# (the feature with the largest information gain)
def ChooseBestFeature(dataset):
    feature_number = len(dataset[0])-1  # number of features (the last column is the label)
    base_entropy = CalculateShannonEnt(dataset)  # entropy of the whole dataset
    best_info_gain = 0.0  # largest information gain found so far
    best_feature = -1     # index of the best feature
for i in xrange(feature_number):
        feat_list = [example[i] for example in dataset]  # the i-th column of feature values
        unique_value = set(feat_list)  # a set removes duplicate values
temp_entropy = 0.0
for value in unique_value:
            subset = SplitDataSet(dataset, i, value)  # split the dataset on the i-th feature
            prob = len(subset)/float(len(dataset))    # fraction of samples in this subset
temp_entropy += prob * CalculateShannonEnt(subset)
infogain = base_entropy - temp_entropy
        #print "information gain of feature %d:" % i, infogain
if infogain > best_info_gain:
best_info_gain = infogain
            best_feature = i
    return best_feature  # index of the best feature
# Vote over the labels and return the label that occurs most often in classlist
def Majoritycount(classlist):
classcount = {}
for vote in classlist:
if vote not in classcount.keys():
classcount[vote] = 0
classcount[vote] += 1
    # sort the labels by count, in descending order
    sortedclasscount = sorted(classcount.iteritems(), key = operator.itemgetter(1), reverse = True)
    # return the label with the largest count
    return sortedclasscount[0][0]
# Tree-building code
def CreateTree(dataset, labels):
    classlist = [example[-1] for example in dataset]  # all class labels
    if classlist.count(classlist[0]) == len(classlist):  # all samples share one class: stop recursing
        return classlist[0]
    if len(dataset[0]) == 1:  # no features left (only the label column): return the majority label
        return Majoritycount(classlist)
    bestfeat = ChooseBestFeature(dataset)
    best_label = labels[bestfeat]
    my_tree = {best_label: {}}
    del(labels[bestfeat])
    feat_values = [example[bestfeat] for example in dataset]  # the column of values of the best feature
    unique_value = set(feat_values)
    for value in unique_value:
        sublabels = labels[:]  # labels with the best feature removed
        my_tree[best_label][value] = CreateTree(SplitDataSet(dataset, bestfeat, value), sublabels)
    return my_tree
# Classify a test vector by walking the tree
def Classify(input_tree, feat_labels, test_vector):
first_feature = input_tree.keys()[0]
first_value = input_tree[first_feature]
feature_index = feat_labels.index(first_feature)
for key in first_value.keys():
if test_vector[feature_index] == key:
            if type(first_value[key]).__name__ == 'dict':  # a dict value means there is a subtree to descend into
class_label = Classify(first_value[key], feat_labels, test_vector)
else:
class_label = first_value[key]
return class_label
# Persist the tree so it does not have to be rebuilt for every classification
def StoreTree(input_tree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(input_tree, fw)
fw.close()
def ReadTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
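# Usage sketch for the two helpers above (the filename 'tree.pkl' is an
# illustrative choice, not part of the original code):
#   StoreTree(lense_tree, 'tree.pkl')
#   restored = ReadTree('tree.pkl')   # returns a dict equal to lense_tree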
# Test case
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lenses_labels = ['age', 'prescript', 'astigmatic1', 'tearRate1']
lense_tree = CreateTree(lenses, lenses_labels)
print lense_tree
DTP.CreatePlot(lense_tree)
'''
dataset, labels = CreateDataSet()
print labels
mytree = DTP.RetrieveTree(0)
print Classify(mytree, labels, [1, 0])
'''
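If the commented-out block is enabled instead, Classify walks the hard-coded tree from DTP.RetrieveTree(0): for the test vector [1, 0] it follows 'no surfacing' = 1 into the 'flippers' subtree and 'flippers' = 0 to a leaf, so it prints 'no'.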
<2> DecisionTreePlotter.py
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 01 11:00:37 2016
@author: G
"""
import matplotlib.pyplot as plt
decision_node = dict(boxstyle = 'sawtooth', fc = "0.8")
leaf_node = dict(boxstyle = 'round4', fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
# Count the leaf nodes; used when plotting to compute node and arrow positions
def GetLeafsNumber(mytree):
    leafs_number = 0
    first_feature = mytree.keys()[0]
    first_value = mytree[first_feature]
for key in first_value.keys():
if type(first_value[key]).__name__ == 'dict':
leafs_number += GetLeafsNumber(first_value[key])
else:
leafs_number += 1
return leafs_number
# Compute the depth of the tree
def GetTreeDepth(mytree):
    max_depth = 0
    first_feature = mytree.keys()[0]
    first_value = mytree[first_feature]
for key in first_value.keys():
if type(first_value[key]).__name__ == 'dict':
            this_depth = GetTreeDepth(first_value[key]) + 1
else:
this_depth = 1
if this_depth > max_depth:
max_depth = this_depth
return max_depth
# Draw a node (an attribute name)
def PlotNode(node_txt, center_pt, parent_pt, node_type):
CreatePlot.ax1.annotate(node_txt, xy = parent_pt, xycoords = "axes fraction",
xytext = center_pt, textcoords = "axes fraction",
va = "center", ha = "center", bbox = node_type, arrowprops = arrow_args)
# Add the edge label (the attribute value) midway along the arrow, i.e. on the branch
def PlotMidText(center_pt, parent_pt, txt_string):
    xmid = (parent_pt[0] - center_pt[0])/2.0 + center_pt[0]*1.03  # the factor 1.03 keeps the label from sitting directly on the branch
ymid = (parent_pt[1] - center_pt[1])/2.0 + center_pt[1]
CreatePlot.ax1.text(xmid, ymid, txt_string, va = "center", ha = "center", rotation = 30)
# Draw the tree
def PlotTree(mytree, parent_pt, nodetxt):
leafs_number = GetLeafsNumber(mytree)
first_feature = mytree.keys()[0]
    # compute center_pt, the centre position of this node
center_pt = (PlotTree.xOff + (1.0 + float(leafs_number))/2.0/PlotTree.totalW, PlotTree.yOff)
PlotMidText(center_pt, parent_pt, nodetxt)
PlotNode(first_feature, center_pt, parent_pt, decision_node)
    # the value under the first key stores this node's subtrees
first_value = mytree[first_feature]
PlotTree.yOff = PlotTree.yOff - 1.0/PlotTree.totalD
for key in first_value.keys():
if type(first_value[key]).__name__ == 'dict':
PlotTree(first_value[key], center_pt, str(key))
else:
PlotTree.xOff = PlotTree.xOff + 1.0/PlotTree.totalW
PlotNode(first_value[key], (PlotTree.xOff, PlotTree.yOff), center_pt, leaf_node)
PlotMidText((PlotTree.xOff, PlotTree.yOff), center_pt, str(key))
PlotTree.yOff = PlotTree.yOff + 1.0/PlotTree.totalD
def CreatePlot(input_tree):
fig = plt.figure(1, facecolor = "white")
fig.clf()
axprops = dict(xticks = [], yticks = [])
CreatePlot.ax1 = plt.subplot(111, frameon = False, **axprops)
PlotTree.totalW = float(GetLeafsNumber(input_tree))
PlotTree.totalD = float(GetTreeDepth(input_tree))
PlotTree.xOff = -0.5/PlotTree.totalW
PlotTree.yOff = 1.0
PlotTree(input_tree, (0.5, 1.0), '')
plt.show()
'''
def CreatePlot():
fig = plt.figure(1, facecolor = "white")
fig.clf()
CreatePlot.ax1 = plt.subplot(111, frameon = False)
PlotNode("a decision node", (0.5, 0.1), (0.1, 0.5), decision_node)
PlotNode("a leaf node", (0.8, 0.1), (0.3, 0.8), leaf_node)
plt.show()
'''
def RetrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
<3> Test results:
lense_tree = {'tearRate1': {'reduced': 'no lenses', 'normal': {'astigmatic1': {'yes':
{'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}},
'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript':
{'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
Experimental data: http://download.csdn.net/detail/hearthougan/9664933
The ID3 decision tree has many shortcomings; for example, it overfits easily, which calls for pruning strategies. For details see Zhou Zhihua's Machine Learning:
http://download.csdn.net/detail/hearthougan/9652865