您的位置:首页 > 其它

西瓜书 课后习题4.3 基于信息熵决策树,连续和离散属性,并验证模型

2018-11-13 10:52 274 查看
import matplotlib.pyplot as plt
import numpy as np
from math import log
import operator
import csv

def readDataset(filename):
    """Read a CSV training set.

    The first column is assumed to be a sample id and is dropped; the
    last column is the class label.  (Generalized from the original
    hard-coded ``[1:9]``/``[1:10]`` slices, which only worked for a
    file with exactly 8 attributes.)

    :param filename: path of the CSV data file
    :return: (dataset, labels) where dataset is a list of rows
        [attr1, ..., attrN, class] and labels is the attribute-name list
    """
    # newline='' is the documented way to open files for the csv module.
    with open(filename, newline='') as f:
        reader = csv.reader(f)
        header_row = next(reader)
        labels = header_row[1:-1]          # drop id column and label column
        dataset = [line[1:] for line in reader]  # drop id column, keep label
    return dataset, labels

def infoEnt(dataset):
    """Shannon entropy (base 2) of the class labels of *dataset*.

    :param dataset: list of samples; the last element of each sample is
        its class label
    :return: the entropy in bits (0 for a pure or empty dataset)
    """
    total = len(dataset)
    counts = {}
    for sample in dataset:
        cls = sample[-1]
        counts[cls] = counts.get(cls, 0) + 1
    return -sum((n / total) * log(n / total, 2) for n in counts.values())

def bestFeatureSplit(dataset):
    """Select the attribute (and split point, if continuous) with the
    highest information gain.

    :param dataset: list of samples; the last element of each sample is
        its class label
    :return: (index of the best attribute, best split threshold for a
        continuous attribute or None for a discrete one)
    """
    numFeature = len(dataset[0]) - 1
    baseInfoEnt = infoEnt(dataset)
    bestInfoGain = 0
    bestFeature = -1
    bestSplitPoint = None
    for i in range(numFeature):
        featList = [example[i] for example in dataset]
        # Heuristic: treat the attribute as continuous when its first
        # value consists only of numeric characters.
        if all(c in "0123456789.-" for c in featList[0]):
            # Bug fix: sort as floats; the original sorted the raw strings,
            # which orders "10" before "9".
            values = sorted(float(feat) for feat in featList)
            # Candidate thresholds: midpoints of consecutive sorted values.
            mediumPoints = [(values[k] + values[k + 1]) / 2
                            for k in range(len(values) - 1)]
            for point in mediumPoints:
                newEnt = 0  # reset for every candidate threshold
                for part in range(2):
                    subDataset = splitDataset(dataset, i, point, True, part)
                    prop = len(subDataset) / float(len(dataset))
                    newEnt += prop * infoEnt(subDataset)
                infoGain = baseInfoEnt - newEnt
                if infoGain > bestInfoGain:
                    bestInfoGain = infoGain
                    bestFeature = i
                    bestSplitPoint = point
        else:
            newEnt = 0
            for value in set(featList):
                # Bug fix: always pass False here; the original reused a
                # `continuous` flag that stayed True once any earlier
                # feature had been detected as continuous.
                subDataset = splitDataset(dataset, i, value, False)
                prop = len(subDataset) / float(len(dataset))
                newEnt += prop * infoEnt(subDataset)
            infoGain = baseInfoEnt - newEnt
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
                bestSplitPoint = None
    return bestFeature, bestSplitPoint

def splitDataset(dataset, axis, value, continuous, part=0):
    """Return the subset of *dataset* selected by a split on one attribute,
    with that attribute's column removed from every returned row.

    :param dataset: list of samples
    :param axis: index of the attribute to split on
    :param value: threshold (continuous) or exact attribute value (discrete)
    :param continuous: True for a numeric threshold split
    :param part: continuous splits only — 0 keeps samples <= value,
        1 keeps samples > value
    :return: the matching samples without the *axis* column
    """
    restDataset = []
    if continuous:
        for featVec in dataset:
            numeric = float(featVec[axis])
            matched = ((part == 0 and numeric <= value)
                       or (part == 1 and numeric > value))
            if matched:
                restDataset.append(featVec[:axis] + featVec[axis + 1:])
    else:
        for featVec in dataset:
            if featVec[axis] == value:
                restDataset.append(featVec[:axis] + featVec[axis + 1:])
    return restDataset

def majorClass(classList):
    """Majority vote over the class labels of a leaf.

    :param classList: class labels of the samples reaching the leaf
    :return: the most frequent label (first-seen wins on ties)
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # max() returns the first item with the maximal count, matching the
    # stable-sort tie-breaking of the original implementation.
    return max(tally.items(), key=operator.itemgetter(1))[0]

def createTree(dataset, labels, datasetFull, labelsFull):
    """Recursively build an ID3-style decision tree (information-gain
    splits; continuous attributes are bi-partitioned at a threshold).

    :param dataset: current training subset; the last element of each
        sample is its class label
    :param labels: attribute names aligned with the columns of *dataset*
        (mutated: a chosen discrete attribute is deleted in place)
    :param datasetFull: the complete original training set, used to
        enumerate attribute values absent from this subset
    :param labelsFull: the complete original attribute-name list
    :return: a nested dict {attr: {value: subtree-or-label}}, or a bare
        class label for a leaf
    """
    classList = [example[-1] for example in dataset]
    # All samples share one class -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column remains -> majority-vote leaf.
    if len(dataset[0]) == 1:
        return majorClass(classList)
    bestFeat, bestSplitPoint = bestFeatureSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    if bestSplitPoint is None:  # discrete attribute
        # Enumerate the attribute's values over the FULL data set so that
        # values missing from this subset (e.g. 色泽:浅白 in 西瓜2.0) still
        # get a branch, answered by majority vote.
        bestFeatIndex = labelsFull.index(bestFeatLabel)
        uniqueValFull = set(example[bestFeatIndex] for example in datasetFull)
        del labels[bestFeat]
        for value in set(example[bestFeat] for example in dataset):
            # Copy: the recursion must not consume sibling branches' labels.
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = createTree(
                splitDataset(dataset, bestFeat, value, False),
                subLabels, datasetFull, labelsFull)
            uniqueValFull.remove(value)
        for value in uniqueValFull:  # values unseen in this subset
            myTree[bestFeatLabel][value] = majorClass(classList)
    else:  # continuous attribute: binary split at bestSplitPoint
        # Bug fix: splitDataset removes the attribute's COLUMN, so its
        # label must be removed too; the original kept it, misaligning
        # labels and columns in deeper recursion (latent on the book data
        # because both children happened to be pure).
        restLabels = labels[:bestFeat] + labels[bestFeat + 1:]
        for part, tag in ((0, '<='), (1, '>')):
            myTree[bestFeatLabel][tag + str(bestSplitPoint)] = createTree(
                splitDataset(dataset, bestFeat, bestSplitPoint, True, part),
                restLabels[:], datasetFull, labelsFull)
    return myTree

def decideTreePredict(decideTree, featList, testData):
    """Classify one sample by walking the decision-tree dict.

    Branch keys of the form '<=x' / '>x' encode continuous thresholds;
    any other key is matched as a discrete attribute value.

    :param decideTree: nested tree dict produced by createTree
    :param featList: full attribute-name list (aligned with testData)
    :param testData: one sample as a list of attribute-value strings
    :return: the predicted class label, or None if no branch matches
    """
    rootFeat = list(decideTree.keys())[0]
    branches = decideTree[rootFeat]
    sampleValue = testData[featList.index(rootFeat)]
    result = None
    for branchKey, subTree in branches.items():
        if branchKey.startswith('<'):       # '<=threshold' branch
            matched = float(sampleValue) <= float(branchKey[2:])
        elif branchKey.startswith('>'):     # '>threshold' branch
            matched = float(sampleValue) > float(branchKey[1:])
        else:                               # discrete value branch
            matched = sampleValue == branchKey
        if matched:
            if isinstance(subTree, dict):
                result = decideTreePredict(subTree, featList, testData)
            else:
                result = subTree
    return result

if __name__ == '__main__':
    filename = 'C:\\Users\\14399\\Desktop\\西瓜3.0.csv'
    dataset, labels = readDataset(filename)
    # Keep untouched copies: createTree mutates `labels` while recursing.
    datasetFull = dataset[:]
    labelsFull = labels[:]
    myTree = createTree(dataset, labels, datasetFull, labelsFull)
    print(myTree)
    # Validate on the training set itself, so accuracy should be 100%.
    count = sum(1 for testData in dataset
                if decideTreePredict(myTree, labelsFull, testData) == testData[-1])
    print(count)

生成结果:{'纹理': {'模糊': '否', '清晰': {'根蒂': {'硬挺': '否', '蜷缩': '是', '稍蜷': {'密度': {'<=0.3815': '否', '>0.3815': '是'}}}}, '稍糊': {'触感': {'软粘': '是', '硬滑': '否'}}}}    ( 与书中结果略有不同,但不影响正确率。)

西瓜3.0数据集:链接:https://pan.baidu.com/s/1RXTUG9gP1Jn9HKFCiEzOlA         密码:3h6n

参考: https://www.geek-share.com/detail/2729164247.html (含画树算法)

         https://www.geek-share.com/detail/2701540653.html

         https://blog.csdn.net/icefire_tyh/article/details/54575527

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: