如何实现并应用决策树算法?
2016-04-24 23:28
603 查看
本文对决策树算法进行简单的总结和梳理,并对著名的决策树算法ID3(Iterative Dichotomiser 迭代二分器)进行实现,实现采用Python语言,一句老梗,“人生苦短,我用Python”,Python确实能够省很多语言方面的事,从而可以让我们专注于问题和解决问题的逻辑。
根据不同的数据,我实现了三个版本的ID3算法,复杂度逐步提升:
1.纯标称值无缺失数据集
2.连续值和标称值混合且无缺失数据集
3.连续值和标称值混合,有缺失数据集
第一个算法参考了《机器学习实战》的大部分代码,第二、三个算法基于前面的实现进行模块的增加。
决策树是一种监督学习的分类算法,目的是学习出一颗决策树,该树中间节点是数据特征,叶子节点是类别,实际分类时根据树的结构,一步一步根据当前数据特征取值选择进入哪一颗子树,直到走到叶子节点,叶子节点的类别就是此决策树对此数据的学习结果。下图就是一颗简单的决策树:
View Code
有缺失值的情况如 西瓜数据集2.0alpha
实验结果:
数据总共有9列,每一列分别代表,以逗号分割
1 Sample code number (病人ID)
2 Clump Thickness 肿块厚度
3 Uniformity of Cell Size 细胞大小的均匀性
4 Uniformity of Cell Shape 细胞形状的均匀性
5 Marginal Adhesion 边缘粘
6 Single Epithelial Cell Size 单上皮细胞的大小
7 Bare Nuclei 裸核
8 Bland Chromatin 乏味染色体
9 Normal Nucleoli 正常核
10 Mitoses 有丝分裂
11 Class: 2 for benign, 4 formalignant(恶性或良性分类)
[from Toby]
总共700条左右的数据,选取最后80条作为测试集,前面作为训练集,进行学习。
使用分类器的代码如下:
训练出的决策树如下:
最终的正确率可以看到:
正确率约为96%左右,算是不差的分类器了。
我的乳腺癌数据见:http://7xt9qk.com2.z0.glb.clouddn.com/breastcancer.txt
至此,决策树算法ID3的实现完毕,下面考虑基于基尼指数和信息增益率进行划分选择,以及考虑实现剪枝过程,因为我们可以看到上面训练出的决策树还存在着很多冗余分支,是因为实现过程中,由于数据量太大,每个分支都不完全纯净,所以会创建往下的分支,但是分支投票的结果又是一致的,而且数据量再大,特征数再多的话,决策树会非常大非常复杂,所以剪枝一般是必做的一步。剪枝分为先剪枝和后剪枝,如果细说的话可以写很多了。
此文亦可见:这里
参考资料:《机器学习》《机器学习实战》通过本次实战也发现了这两本书中的一些错误之处。
lz初学机器学习不久,如有错漏之处请多包涵指出或者各位有什么想法或意见欢迎评论去告诉我:)
根据不同的数据,我实现了三个版本的ID3算法,复杂度逐步提升:
1.纯标称值无缺失数据集
2.连续值和标称值混合且无缺失数据集
3.连续值和标称值混合,有缺失数据集
第一个算法参考了《机器学习实战》的大部分代码,第二、三个算法基于前面的实现进行模块的增加。
决策树简介
决策树算法不用说大家应该都知道,是机器学习的一个著名算法,由澳大利亚著名计算机科学家Rose Quinlan发表。决策树是一种监督学习的分类算法,目的是学习出一颗决策树,该树中间节点是数据特征,叶子节点是类别,实际分类时根据树的结构,一步一步根据当前数据特征取值选择进入哪一颗子树,直到走到叶子节点,叶子节点的类别就是此决策树对此数据的学习结果。下图就是一颗简单的决策树:
from math import log from operator import itemgetter def filetoDataSet(filename): fr = open(filename,'r') all_lines = fr.readlines() featname = all_lines[0].strip().split(',')[1:-1] dataSet = [] for line in all_lines[1:]: line = line.strip() lis = line.split(',')[1:] if lis[-1] == '2': lis[-1] = '良' else: lis[-1] = '恶' dataSet.append(lis) fr.close() return dataSet,featname def calcEnt(dataSet, weight): #计算权重香农熵 labelCounts = {} i = 0 for featVec in dataSet: label = featVec[-1] if label not in labelCounts.keys(): labelCounts[label] = 0 labelCounts[label] += weight[i] i += 1 Ent = 0.0 for key in labelCounts.keys(): p_i = float(labelCounts[key]/sum(weight)) Ent -= p_i * log(p_i,2) return Ent def splitDataSet(dataSet, weight, axis, value, countmissvalue): #划分数据集,找出第axis个属性为value的数据 returnSet = [] returnweight = [] i = 0 for featVec in dataSet: if featVec[axis] == '?' and (not countmissvalue): continue if countmissvalue and featVec[axis] == '?': retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) if featVec[axis] == value: retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) returnweight.append(weight[i]) i += 1 return returnSet,returnweight def splitDataSet_for_dec(dataSet, axis, value, small, countmissvalue): returnSet = [] for featVec in dataSet: if featVec[axis] == '?' and (not countmissvalue): continue if countmissvalue and featVec[axis] == '?': retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) if (small and featVec[axis] <= value) or ((not small) and featVec[axis] > value): retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) return returnSet def DataSetPredo(filename,decreteindex): #首先运行,权重不变为1 dataSet,featname = filetoDataSet(filename) DataSetlen = len(dataSet) Entropy = calcEnt(dataSet,[1 for i in range(DataSetlen)]) for index in decreteindex: #对每一个是连续值的属性下标 UnmissDatalen = 0 for i in range(DataSetlen): #字符串转浮点数 if dataSet[i][index] != '?': UnmissDatalen += 1 dataSet[i][index] = int(dataSet[i][index]) allvalue = [vec[index] for vec in dataSet if vec[index] != '?'] sortedallvalue = sorted(allvalue) T = [] for i in range(len(allvalue)-1): #划分点集合 T.append(int(sortedallvalue[i]+sortedallvalue[i+1])/2.0) bestGain = 0.0 bestpt = -1.0 for pt in T: #对每个划分点 nowent = 0.0 for small in range(2): #化为正类(1)负类(0) Dt = splitDataSet_for_dec(dataSet, index, pt, small, False) p = len(Dt) / float(UnmissDatalen) nowent += p * calcEnt(Dt,[1.0 for i in range(len(Dt))]) if Entropy - nowent > bestGain: bestGain = Entropy-nowent bestpt = pt featname[index] = str(featname[index]+"<="+"%d"%bestpt) for i in range(DataSetlen): if dataSet[i][index] != '?': dataSet[i][index] = "是" if dataSet[i][index] <= bestpt else "否" return dataSet,featname def getUnmissDataSet(dataSet, weight, axis): returnSet = [] returnweight = [] tag = [] i = 0 for featVec in dataSet: if featVec[axis] == '?': tag.append(i) else: retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) i += 1 for i in range(len(weight)): if i not in tag: returnweight.append(weight[i]) return returnSet,returnweight def printlis(lis): for li in lis: print(li) def chooseBestFeat(dataSet,weight,featname): numFeat = len(dataSet[0])-1 DataSetWeight = sum(weight) bestGain = 0.0 bestFeat = -1 for i in range(numFeat): UnmissDataSet,Unmissweight = getUnmissDataSet(dataSet, weight, i) #无缺失值数据集及其权重 Entropy = calcEnt(UnmissDataSet,Unmissweight) #Ent(D~) allvalue = [featVec[i] for featVec in dataSet if featVec[i] != '?'] UnmissSumWeight = sum(Unmissweight) lou = UnmissSumWeight / DataSetWeight #lou specvalue = set(allvalue) nowEntropy = 0.0 for v in specvalue: #该属性的几种取值 Dv,weightVec_v = splitDataSet(dataSet,Unmissweight,i,v,False) #返回 此属性为v的所有样本 以及 每个样本的权重 p = sum(weightVec_v) / UnmissSumWeight #r~_v = D~_v / D~ nowEntropy += p * calcEnt(Dv,weightVec_v) if lou*(Entropy - nowEntropy) > bestGain: bestGain = Entropy - nowEntropy bestFeat = i return bestFeat def Vote(classList,weight): classdic = {} i = 0 for vote in classList: if vote not in classdic.keys(): classdic[vote] = 0 classdic[vote] += weight[i] i += 1 sortedclassDic = sorted(classdic.items(),key=itemgetter(1),reverse=True) return sortedclassDic[0][0] def splitDataSet_adjustWeight(dataSet,weight,axis,value,r_v): returnSet = [] returnweight = [] i = 0 for featVec in dataSet: if featVec[axis] == '?': retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) returnweight.append(weight[i] * r_v) elif featVec[axis] == value: retVec = featVec[:axis] retVec.extend(featVec[axis+1:]) returnSet.append(retVec) returnweight.append(weight[i]) i += 1 return returnSet,returnweight def createDecisionTree(dataSet,weight,featnames): featname = featnames[:] ################ classlist = [featvec[-1] for featvec in dataSet] #此节点的分类情况 if classlist.count(classlist[0]) == len(classlist): #全部属于一类 return classlist[0] if len(dataSet[0]) == 1: #分完了,没有属性了 return Vote(classlist,weight) #少数服从多数 # 选择一个最优特征进行划分 bestFeat = chooseBestFeat(dataSet,weight,featname) bestFeatname = featname[bestFeat] del(featname[bestFeat]) #防止下标不准 DecisionTree = {bestFeatname:{}} # 创建分支,先找出所有属性值,即分支数 allvalue = [vec[bestFeat] for vec in dataSet if vec[bestFeat] != '?'] specvalue = sorted(list(set(allvalue))) #使有一定顺序 UnmissDataSet,Unmissweight = getUnmissDataSet(dataSet, weight, bestFeat) #无缺失值数据集及其权重 UnmissSumWeight = sum(Unmissweight) # D~ for v in specvalue: copyfeatname = featname[:] Dv,weightVec_v = splitDataSet(dataSet,Unmissweight,bestFeat,v,False) #返回 此属性为v的所有样本 以及 每个样本的权重 r_v = sum(weightVec_v) / UnmissSumWeight #r~_v = D~_v / D~ sondataSet,sonweight = splitDataSet_adjustWeight(dataSet,weight,bestFeat,v,r_v) DecisionTree[bestFeatname][v] = createDecisionTree(sondataSet,sonweight,copyfeatname) return DecisionTree if __name__ == '__main__': filename = "D:\\MLinAction\\Data\\breastcancer.txt" DataSet,featname = DataSetPredo(filename,[0,1,2,3,4,5,6,7,8]) Tree = createDecisionTree(DataSet,[1.0 for i in range(len(DataSet))],featname) print(Tree)
View Code
有缺失值的情况如 西瓜数据集2.0alpha
实验结果:
在乳腺癌数据集上的测试与表现
有了算法,我们当然想做一定的测试看一看算法的表现。这里我选择了威斯康辛女性乳腺癌的数据。数据总共有9列,每一列分别代表,以逗号分割
1 Sample code number (病人ID)
2 Clump Thickness 肿块厚度
3 Uniformity of Cell Size 细胞大小的均匀性
4 Uniformity of Cell Shape 细胞形状的均匀性
5 Marginal Adhesion 边缘粘
6 Single Epithelial Cell Size 单上皮细胞的大小
7 Bare Nuclei 裸核
8 Bland Chromatin 乏味染色体
9 Normal Nucleoli 正常核
10 Mitoses 有丝分裂
11 Class: 2 for benign, 4 formalignant(恶性或良性分类)
[from Toby]
总共700条左右的数据,选取最后80条作为测试集,前面作为训练集,进行学习。
使用分类器的代码如下:
import treesID3 as id3 import treePlot as tpl import pickle def classify(Tree, featnames, X): classLabel = "未知" root = list(Tree.keys())[0] firstGen = Tree[root] featindex = featnames.index(root) #根节点的属性下标 for key in firstGen.keys(): #根属性的取值,取哪个就走往哪颗子树 if X[featindex] == key: if type(firstGen[key]) == type({}): classLabel = classify(firstGen[key],featnames,X) else: classLabel = firstGen[key] return classLabel def StoreTree(Tree,filename): fw = open(filename,'wb') pickle.dump(Tree,fw) fw.close() def ReadTree(filename): fr = open(filename,'rb') return pickle.load(fr) if __name__ == '__main__': filename = "D:\\MLinAction\\Data\\breastcancer.txt" dataSet,featnames = id3.DataSetPredo(filename,[0,1,2,3,4,5,6,7,8]) Tree = id3.createDecisionTree(dataSet[:620],[1.0 for i in range(len(dataSet))],featnames) tpl.createPlot(Tree) storetree = "D:\\MLinAction\\Data\\decTree.dect" StoreTree(Tree,storetree) #Tree = ReadTree(storetree) i = 1 cnt = 0 for lis in dataSet[620:]: judge = classify(Tree,featnames,lis[:-1]) shouldbe = lis[-1] if judge == shouldbe: cnt += 1 print("Test %d was classified %s, it's class is %s %s" %(i,judge,shouldbe,"=====" if judge==shouldbe else "")) i += 1 print("The Tree's Accuracy is %.3f" % (cnt / float(i)))
训练出的决策树如下:
最终的正确率可以看到:
正确率约为96%左右,算是不差的分类器了。
我的乳腺癌数据见:http://7xt9qk.com2.z0.glb.clouddn.com/breastcancer.txt
至此,决策树算法ID3的实现完毕,下面考虑基于基尼指数和信息增益率进行划分选择,以及考虑实现剪枝过程,因为我们可以看到上面训练出的决策树还存在着很多冗余分支,是因为实现过程中,由于数据量太大,每个分支都不完全纯净,所以会创建往下的分支,但是分支投票的结果又是一致的,而且数据量再大,特征数再多的话,决策树会非常大非常复杂,所以剪枝一般是必做的一步。剪枝分为先剪枝和后剪枝,如果细说的话可以写很多了。
此文亦可见:这里
参考资料:《机器学习》《机器学习实战》通过本次实战也发现了这两本书中的一些错误之处。
lz初学机器学习不久,如有错漏之处请多包涵指出或者各位有什么想法或意见欢迎评论去告诉我:)
相关文章推荐
- Redis源码分析——链表
- 个人工作总结7
- 剑桥大学的教育
- 使用git远程仓库
- 深度学习笔记——参考条目
- 浅谈Swift和Objective-C之间的那点事。。。
- 冲刺阶段第六天,4月24日。
- 安装交叉编译环境
- njust 1928 puzzle (2-sat)
- Eclipse常用快捷键
- android studio 编译的时候出现的错误和解决方法
- PHP7之Closure::call()
- 常用排序算法之快速排序
- Java 浅拷贝和深拷贝那些事
- 20145330Java程序设计第三次实验
- iOS环信集成<1>
- 使用微软的TFS团队开发
- RNN学习笔记:Understanding Deep Architectures using a Recursive Convolutional Network
- Android 简单自定义view--倒计时
- 史上最全最强SpringMVC详细示例实战教程