您的位置:首页 > 编程语言

CART分类回归树-Gini系数 基于原理的代码练习

2020-08-28 20:28 1271 查看

特征选择,决策树生成,决策树剪枝

特征选择

gini系数计算函数。
按照特征进行dataframe划分函数。
在给定特征的条件概率下对其他特征进行遍历,找到gini系数最小的特征,并返回gini系数值以及按照此特征重新划分后的子空间。

# Read the example data set (weather/play toy data shown below).
data=pd.read_csv('example_data.csv')
#pandas-framework
#  humility   outlook  temp  windy play
# 0      high     sunny   hot  False   no
# 1      high     sunny   hot   True   no
# 2      high  overcast   hot  False  yes
# 3      high     rainy  mild  False  yes
# 4    normal     rainy  cool  False  yes
# 5    normal     rainy  cool   True   no
# 6    normal  overcast  cool   True  yes
# 7      high     sunny  mild  False   no
# 8    normal     sunny  cool  False  yes
# 9    normal     rainy  mild  False  yes
# 10   normal     sunny  mild   True  yes
# 11     high  overcast  mild   True  yes
# 12   normal  overcast   hot  False  yes
# 13     high     rainy  mild   True   no
#gini系数计算
def gini(a):
    """Return the Gini impurity of the label sequence *a*.

    Gini(D) = sum_k p_k * (1 - p_k) = 1 - sum_k p_k**2, where p_k is the
    relative frequency of class k in *a*.  An empty sequence is treated as
    pure and yields 0.0 (the original raised ZeroDivisionError).

    >>> round(gini(['a', 'b', 'c', 'd', 'b', 'c', 'a', 'b', 'c', 'd', 'a']), 12)
    0.743801652893
    """
    # Local import: no top-of-file import block is visible in this article.
    from collections import Counter

    n = len(a)
    if n == 0:
        # An empty split carries no impurity; also avoids dividing by zero.
        return 0.0
    # One O(n) counting pass; the original called list.count once per
    # distinct class (O(n * k)) and only worked on lists.
    return 1.0 - sum((count / n) ** 2 for count in Counter(a).values())
#按特征进行划分
def spilt_framework(data, col):
    """Partition *data* by the distinct values of column *col*.

    Returns a dict mapping each unique value appearing in ``data[col]`` to
    the sub-DataFrame whose rows carry that value.
    """
    return {value: data[data[col] == value] for value in data[col].unique()}
#test
# print(spilt_framework(data,'humility'))
# {'high':    humility   outlook  temp  windy play
# 0      high     sunny   hot  False   no
# 1      high     sunny   hot   True   no
# 2      high  overcast   hot  False  yes
# 3      high     rainy  mild  False  yes
# 7      high     sunny  mild  False   no
# 11     high  overcast  mild   True  yes
# 13     high     rainy  mild   True   no, 'normal':    humility   outlook  temp  windy play
# 4    normal     rainy  cool  False  yes
# 5    normal     rainy  cool   True   no
# 6    normal  overcast  cool   True  yes
# 8    normal     sunny  cool  False  yes
# 9    normal     rainy  mild  False  yes
# 10   normal     sunny  mild   True  yes
# 12   normal  overcast   hot  False  yes}
#对每个子空间计算每个特征对应下的gini数,数值最小即为最优的特征
#返回最小gini值,特征,划分好的空间值
def choose_best_col(data, label):
    """Pick the feature whose split minimises the weighted Gini impurity.

    Parameters
    ----------
    data : pandas.DataFrame holding the candidate feature columns and *label*.
    label : str, name of the target column.

    Returns ``(min_value, best_col, min_splited)``: the weighted Gini
    Gini(D, A) of the best feature, its column name (``None`` when no
    feature column remains), and the dict of sub-DataFrames produced by
    splitting on it.
    """
    feature_cols = [col for col in data.columns if col != label]
    min_value, best_col = float('inf'), None
    min_splited = None
    for col in feature_cols:
        splited_set = spilt_framework(data, col)
        # Gini(D, A) = sum_i |D_i|/|D| * Gini(D_i).  BUG FIX: the original
        # assigned (=) instead of accumulating (+=), so only the LAST
        # subset's contribution was ever compared.
        gini_DA = 0.0
        for subset in splited_set.values():
            gini_DA += len(subset) / len(data) * gini(subset[label].tolist())
        if gini_DA < min_value:
            min_value, best_col = gini_DA, col
            min_splited = splited_set
    return min_value, best_col, min_splited
# example:
# print(choose_best_col(data,'play'))
# (0.12244897959183676, 'humility', {'high':    humility   outlook  temp  windy play
# 0      high     sunny   hot  False   no
# 1      high     sunny   hot   True   no
# 2      high  overcast   hot  False  yes
# 3      high     rainy  mild  False  yes
# 7      high     sunny  mild  False   no
# 11     high  overcast  mild   True  yes
# 13     high     rainy  mild   True   no, 'normal':    humility   outlook  temp  windy play
# 4    normal     rainy  cool  False  yes
# 5    normal     rainy  cool   True   no
# 6    normal  overcast  cool   True  yes
# 8    normal     sunny  cool  False  yes
# 9    normal     rainy  mild  False  yes
# 10   normal     sunny  mild   True  yes
# 12   normal  overcast   hot  False  yes})
##生成决策树
class CartTree:
    """CART decision tree grown greedily on minimum weighted Gini impurity."""

    class Node:
        """A tree node: a feature name (or leaf label) plus outgoing edges."""

        def __init__(self, name):
            self.name = name
            # Maps a feature value (edge label) to the child Node.
            self.connections = {}

        def connect(self, label, node):
            self.connections[label] = node

    def __init__(self, data, label):
        # Keep the full column set, the training frame, the target column
        # name, and a synthetic root the real tree hangs from.
        self.columns = data.columns
        self.data = data
        self.label = label
        self.root = self.Node("Root")

    def print_tree(self, node, tabs):
        """Pretty-print the subtree rooted at *node*, one tab level per depth."""
        print(tabs + node.name)
        for connection, child_node in node.connections.items():
            # The edge label (feature value) goes on its own parenthesised line.
            print(f"{tabs}\t({connection!s})")
            self.print_tree(child_node, tabs + "\t\t")

    def construct_tree(self):
        """Build the whole tree from the stored data under the root node."""
        self.construct(self.root, "", self.data, self.columns)

    def construct(self, parent_node, parent_connection_label, input_data, columns):
        """Recursively grow the subtree for *input_data* restricted to *columns*."""
        min_value, best_col, min_splited = choose_best_col(input_data[columns], self.label)

        if not best_col:
            # No feature column left: emit a leaf carrying the label value of
            # the first remaining row.
            leaf = self.Node(input_data[self.label].iloc[0])
            parent_node.connect(parent_connection_label, leaf)
            return

        branch = self.Node(best_col)
        parent_node.connect(parent_connection_label, branch)

        # The chosen feature is consumed; descend with the remaining ones.
        remaining = [col for col in columns if col != best_col]
        for splited_value, splited_data in min_splited.items():
            self.construct(branch, splited_value, splited_data, remaining)

主程序

# Load the training data and fit a CART tree predicting the 'play' column.
data=pd.read_csv('example_data.csv')
tree=CartTree(data,'play')
tree.construct_tree()
# Print the tree starting from the root with a single-space base indent.
tree.print_tree(tree.root, " ")

结果:

Root
()
temp
(hot)
humility
(high)
outlook
(sunny)
windy
(False)
no
(True)
no
(overcast)
windy
(False)
yes
(normal)
outlook
(overcast)
windy
(False)
yes
(mild)
humility
(high)
outlook
(rainy)
windy
(False)
yes
(True)
no
(sunny)
windy
(False)
no
(overcast)
windy
(True)
yes
(normal)
outlook
(rainy)
windy
(False)
yes
(sunny)
windy
(True)
yes
(cool)
outlook
(rainy)
windy
(False)
humility
(normal)
yes
(True)
humility
(normal)
no
(overcast)
humility
(normal)
windy
(True)
yes
(sunny)
humility
(normal)
windy
(False)
yes

Process finished with exit code 0
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: