Python实现决策树算法
2015-06-30 21:00
639 查看
# -*-coding:utf-8-*- ''' 决策树算法 ''' from __future__ import division import matplotlib.pyplot as plt from math import log import operator import pickle def createDataSet(): dataset = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']] labels = ['no sufacing','flippers'] return dataset,labels def calShannonEnt(dataSet): ''' 计算给定数据集的熵 ''' numEntries = len(dataSet) # 数据集 labCounts = {} # 存储每一类数量 for _featVec_ in dataSet: # 对每一训练数据 _currentLabel_ = _featVec_[-1] # 分类标 if _currentLabel_ not in labCounts.keys(): labCounts[_currentLabel_] = 1 else: labCounts[_currentLabel_] += 1 shannonEnt = 0.0 for _key_ in labCounts: # 对每一分类,计算熵 _prob = float(labCounts[_key_]) / numEntries # P(x) shannonEnt -= _prob * log(_prob,2) # P(x) * log(P(x)) return shannonEnt def splitDataSet(dataset,axis,value): ''' 按照给定特征划分数据集 ''' retDataSet = [] for _featVec_ in dataset: # 每一训练数据 if _featVec_[axis] == value: # 判断特征值 ?= 指定值 reducedFeatVec = _featVec_[:axis] # 在新列表中加载除该特征前面的所有特征 reducedFeatVec.extend(_featVec_[axis+1:]) # 加载该特征值后面的所有特征 retDataSet.append(reducedFeatVec) return retDataSet # 满足特征 = 指定值的数据,分割完成 def chooseBestFeatureToSplit(dataSet): ''' 根据信息增益最大,选择最好的数据集划分特征 ''' _numFeatures = len(dataSet[0]) - 1 # 训练集特征个数 _baseEntropy = calShannonEnt(dataSet) # 数据集的熵 _bestInfoGain = 0.0 # 信息增益 bestFeature = -1 # 最优特征 for i in range(_numFeatures): # 对数据的每一个特征 featList = [example[i] for example in dataSet] # 提取所有训练样本中第i个特征 ---> list _uniqueVals = set(featList) # 特征值的所有取值 _newEntropy = 0.0 for _value_ in _uniqueVals: # 计算该特征下的熵 _subDataSet = splitDataSet(dataSet,i,_value_) # 按照特征i分割数据 _prob = len(_subDataSet) / float(len(dataSet)) # 特征i下,分别取不同特征值的概率p() _newEntropy += _prob * calShannonEnt(_subDataSet) # 计 4000 算特征i的熵 infoGain = _baseEntropy - _newEntropy # 特征值i的信息增益 if (infoGain > _bestInfoGain): # 取最大信息增益时的特征i _bestInfoGain = infoGain bestFeature = i return bestFeature # 返回数据集的最优分割特征 def majorityCnt(classList): ''' 最多数决定叶子节点的分类 ''' _classCount = {} for _vote_ in classList: if _vote_ not in _classCount.keys(): _classCount[_vote_] = 0 _classCount[_vote_] += 1 sortedClassCount = sorted(_classCount.iteritems(),key=operator.itemgetter(1),reverse=True) return sortedClassCount[0][0] # 排序后返回出现次数最多的分类名称 def createTree(dataSet,labels): ''' 创建树,数据集和特征标签 ''' _label = [l for l in labels] _classList = [_example_[-1] for _example_ in dataSet] # 数据集的所有分类标签列表 if _classList.count(_classList[0]) == len(_classList): # 只有一个分类标签,结束,返回 return _classList[0] if len(dataSet[0]) == 1: # 如果训练数据集只有一列,必定是分类标签,返回其中出现次数最多的分类 return majorityCnt(_classList) _bestFeat = chooseBestFeatureToSplit(dataSet) # 信息增益最大的特征 _bestFeatLabel = _label[_bestFeat] # 信息增益最大的特征标签 myTree = {_bestFeatLabel:{}} # 开始建树 del(_label[_bestFeat]) # 将已经建树的特征从数据集中删除 _featValues = [_example_[_bestFeat] for _example_ in dataSet] # 特征值列表 uniqueVals = set(_featValues) # 特征值的不同取值 for _value_ in uniqueVals: subLabels = _label[:] # 对特征的每一个取值,建支树 myTree[_bestFeatLabel][_value_] = createTree(splitDataSet(dataSet,_bestFeat,_value_),subLabels) return myTree def classify(inputTree,featLabels,testVec): ''' 根据训练决策树,判断测试向量testVec ''' _firstStr = inputTree.keys()[0] # 树根节点 _secondDict = inputTree[_firstStr] # 支树 _featIndex = featLabels.index(_firstStr) # 树根节点 ---> 特征位置 ---> 测试向量位置 for key in _secondDict.keys(): if testVec[_featIndex] == key: if type(_secondDict[key]).__name__ == 'dict': classLabel = classify(_secondDict[key],featLabels,testVec) else: classLabel = secondDict[key] return classLabel def storeTree(inputTree,filename): ''' 保存决策树 ''' f = open(filename,'w') pickle.dump(inputTree,f) f.close() def grabTree(filename): ''' 从文件中取出决策树 ''' f = open(filename) return pickle.load(f) file = 'lenses.txt' lenses = [ins.strip().split('\t') for ins in open(file).readlines()] lensesLabels = ['age','prescript','astigmatic','tearRate'] lensesTree = createTree(lenses,lensesLabels)参考:《Machine Learning in Action》
相关文章推荐
- Python动态类型的学习---引用的理解
- Python3写爬虫(四)多线程实现数据爬取
- 垃圾邮件过滤器 python简单实现
- 下载并遍历 names.txt 文件,输出长度最长的回文人名。
- install and upgrade scrapy
- Scrapy的架构介绍
- Centos6 编译安装Python
- 使用Python生成Excel格式的图片
- 让Python文件也可以当bat文件运行
- [Python]推算数独
- Python中zip()函数用法举例
- Python中map()函数浅析
- Python在CAM软件Genesis2000中的应用
- 使用Shiboken为C++和Qt库创建Python绑定
- FREEBASIC 编译可被python调用的dll函数示例
- Python 七步捉虫法
- Python实现的基于ADB的Android远程工具