您的位置:首页 > 编程语言 > Python开发

python 决策树实现案例

2017-11-26 21:30 204 查看
根据培训班课程内容编写。

在 txt 文件中写入如下数据（逗号分隔，第一行为表头）：

RID,age,income,student,credit_rating,Class_buys_computer

1,youth,high,no,fair,no

2,youth,high,no,excellent,no

3,middle_aged,high,no,fair,yes

4,senior,medium,no,fair,yes

5,senior,low,yes,fair,yes

6,senior,low,yes,excellent,no

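7,middle_aged,low,yes,excellent,yes

（注：原文数据在此处被截断，后续记录缺失；该数据为 Han & Kamber《数据挖掘：概念与技术》中的 AllElectronics 顾客数据表，完整版共 14 条记录。）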

# scikit-learn estimators only accept numeric input; DictVectorizer converts the
# categorical string features into numeric (one-hot 0/1) columns.

from sklearn.feature_extraction import DictVectorizer

import csv

from sklearn import preprocessing

from sklearn import tree

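# StringIO: an in-memory text buffer for reading/writing. It is not used below,
# and sklearn.externals.six no longer exists in recent scikit-learn releases
# (io.StringIO is the standard-library equivalent).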

from sklearn.externals.six import StringIO

# Read the training table. Expected layout (see the sample data): a header
# row, then one record per line with the row ID first, the class label last
# and the categorical feature values in between.
# Python 3's csv module needs a text-mode file ('rt', not Python 2's 'rb')
# opened with newline='' so that it handles line endings itself. The with-block
# guarantees the file is closed once it has been read.
with open(r'E:\python_excel\jueceshu.txt', 'rt', newline='',
          encoding='utf-8-sig') as allElectronicsData:  # utf-8-sig also drops a BOM
    reader = csv.reader(allElectronicsData)

    headers = next(reader, None)
    if headers is None:
        raise ValueError("training file is empty: no header row found")
    print(headers)

    featureList = []  # one {column_name: value} dict per record (ID and label excluded)
    labelList = []    # class label of each record, aligned with featureList

    for row in reader:
        # Skip blank lines (e.g. an empty line between records) instead of
        # failing on them.
        if not any(field.strip() for field in row):
            continue
        # A row with too few/many fields would put values under the wrong
        # column names, so reject it with its line number.
        if len(row) != len(headers):
            raise ValueError(
                "line %d: expected %d fields, got %d: %r"
                % (reader.line_num, len(headers), len(row), row)
            )
        labelList.append(row[-1])
        # Features are every column except the ID (first) and the label (last).
        featureList.append(dict(zip(headers[1:-1], row[1:-1])))

print(featureList)

# scikit-learn estimators need numeric input: DictVectorizer one-hot encodes
# every categorical feature into 0/1 columns named "<feature>=<value>"
# (sorted by that name by default).
vec = DictVectorizer()
dummyX = vec.fit_transform(featureList).toarray()  # sparse matrix -> dense ndarray
print("dummyX:"+str(dummyX))

# get_feature_names() was removed in scikit-learn 1.2 in favour of
# get_feature_names_out(); support both so the script runs on old and new
# releases.
if hasattr(vec, "get_feature_names_out"):
    featureNames = list(vec.get_feature_names_out())
else:
    featureNames = vec.get_feature_names()
print(featureNames)

print("labelList:"+str(labelList))

# Encode the string class labels as 0/1. With exactly two classes
# LabelBinarizer returns a single column (one row per sample).
lb = preprocessing.LabelBinarizer()
dummyY = lb.fit_transform(labelList)
print("dummyY:"+str(dummyY))

# criterion='entropy' chooses splits by information gain rather than the
# default Gini impurity.
clf = tree.DecisionTreeClassifier(criterion='entropy')
clf = clf.fit(dummyX, dummyY)
print("clf:"+str(clf))

# Dump the fitted tree in Graphviz DOT format. export_graphviz writes into
# out_file and returns None, so its result is deliberately not assigned
# (the original rebound the file handle `f` to that None).
with open("allElectronicInformationGainOri.dot", 'w', encoding='utf-8') as f:
    tree.export_graphviz(clf, feature_names=featureNames, out_file=f)

    

# Build a hypothetical new sample from the first training record and ask the
# fitted tree to classify it.
oneRowX = dummyX[0, :]
print("oneRowX:"+str(oneRowX))

# dummyX[0] is a view into dummyX, so copy it before editing; otherwise the
# assignments below would also overwrite the training matrix (and oneRowX).
newRowX = [dummyX[0].copy()]  # predict() expects 2-D input: a list of rows

# Change the sample's age from "youth" to "middle_aged". The one-hot columns
# are looked up by name via vocabulary_ rather than hard-coded (for the sample
# data they are columns 0 and 2), so the edit targets the right columns even
# if the feature layout changes; a missing category raises KeyError instead of
# silently editing the wrong column.
newRowX[0][vec.vocabulary_['age=middle_aged']] = 1
newRowX[0][vec.vocabulary_['age=youth']] = 0
print("newRowX:"+str(newRowX))

# predictedY is in LabelBinarizer encoding (1 means lb.classes_[1]);
# lb.inverse_transform(predictedY) maps it back to the original label.
predictedY = clf.predict(newRowX)
print("predictedY:"+str(predictedY))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: