您的位置:首页 > 编程语言 > Python开发


2018-02-13 11:20 846 查看
# Train and evaluate Multinomial Naive Bayes on the 20 Newsgroups corpus using
# two stop-word-filtered text representations: raw term counts and TF-IDF.
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics import classification_report
# `sklearn.cross_validation` was removed from scikit-learn; `train_test_split`
# is provided by `sklearn.model_selection`.
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

news = fetch_20newsgroups(subset='all')

# Hold out 25% of the documents for testing (this matches the 4712 test
# samples reported in the printed results).
# NOTE(review): the original random_state is not shown; 33 is an arbitrary
# fixed seed used only to make runs reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    news.data, news.target, test_size=0.25, random_state=33)

count_filter_vec = CountVectorizer(analyzer='word', stop_words='english')
tfidf_filter_vec = TfidfVectorizer(analyzer='word', stop_words='english')

# Learn each vocabulary on the training split only, then reuse it for the
# test split so that no test information leaks into the features.
X_count_filter_train = count_filter_vec.fit_transform(X_train)
X_count_filter_test = count_filter_vec.transform(X_test)

X_tfidf_filter_train = tfidf_filter_vec.fit_transform(X_train)
X_tfidf_filter_test = tfidf_filter_vec.transform(X_test)

# One classifier per representation: a model fitted on count features must
# not be scored on TF-IDF features.
mnb_count = MultinomialNB()
mnb_count.fit(X_count_filter_train, y_train)

print('The accuracy of  Classifier 20newsgroups using Native Bayes (countvectorizer by filtering stopwords: ',mnb_count.score(X_count_filter_test,y_test))

y_count_filter_predict = mnb_count.predict(X_count_filter_test)

mnb_tfidf = MultinomialNB()
mnb_tfidf.fit(X_tfidf_filter_train, y_train)

print('The accuracy of  Classifier 20newsgroups using Native Bayes (Tfidf by filtering stopwords: ',mnb_tfidf.score(X_tfidf_filter_test,y_test))

y_Tfidf_filter_predict = mnb_tfidf.predict(X_tfidf_filter_test)

print(classification_report(y_test,y_count_filter_predict,target_names = news.target_names))

from matplotlib import pyplot as plt
import numpy as np

def show_values(pc, fmt="%.2f", **kw):
    '''
    Write each cell's value as text at the centre of that cell of a heatmap.

    Source: https://stackoverflow.com/a/25074150/395857 By HYRY

    Parameters:
        pc: the collection returned by ``ax.pcolor``; its paths, face colours
            and data array are read, and text is added to ``pc.axes``.
        fmt: ``%``-style format applied to every cell value.
        **kw: extra keyword arguments forwarded to ``ax.text``.
    '''
    # Face colours are derived from the colormap lazily; force the mapping so
    # that get_facecolors() returns the colours that will actually be drawn.
    pc.update_scalarmappable()
    ax = pc.axes
    # Flatten the data so the values pair up one-to-one with the cell paths
    # (NOTE(review): newer Matplotlib versions may return a 2-D array here).
    values = pc.get_array().ravel()
    for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # The trailing two vertices close the polygon; the mean of the
        # remaining corners is the cell centre.
        x, y = p.vertices[:-2, :].mean(0)
        # Black text on light cells, white text on dark cells.
        if np.all(color[:3] > 0.5):
            text_color = (0.0, 0.0, 0.0)
        else:
            text_color = (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)

def cm2inch(*tupl):
    '''
    Convert lengths given in centimetres to inches (for matplotlib figure sizes).

    Source: https://stackoverflow.com/a/22787457/395857 By gns-ank

    Accepts either separate numbers, ``cm2inch(w, h)``, or a single sequence,
    ``cm2inch((w, h))``. Returns a tuple of the converted values; calling it
    with no arguments returns an empty tuple.
    '''
    inch = 2.54  # centimetres per inch
    if not tupl:
        return ()
    # A leading tuple/list means the sizes were passed as one sequence.
    values = tupl[0] if isinstance(tupl[0], (tuple, list)) else tupl
    return tuple(i / inch for i in values)

def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels, figure_width=40, figure_height=20, correct_orientation=False, cmap='RdBu'):
    '''
    Draw a 2-D array as an annotated heatmap on a new matplotlib figure.

    Inspired by:
    - https://stackoverflow.com/a/16124677/395857
    - https://stackoverflow.com/a/25074150/395857

    Parameters:
        AUC: 2-D numpy array; one heatmap cell per element.
        title, xlabel, ylabel: figure title and axis labels.
        xticklabels, yticklabels: one label per column / per row.
        figure_width, figure_height: figure size in centimetres.
        correct_orientation: if true, put the origin at the top left
            instead of the bottom left.
        cmap: matplotlib colormap used for the cells.
    '''
    # Plot it out
    fig, ax = plt.subplots()
    c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2, cmap=cmap)

    # Put the major ticks at the middle of each cell
    ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)

    # Set tick labels
    ax.set_xticklabels(xticklabels, minor=False)
    ax.set_yticklabels(yticklabels, minor=False)

    # Set title and x/y labels
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Remove last blank column
    ax.set_xlim(0, AUC.shape[1])

    # Proper orientation (origin at the top left instead of bottom left)
    if correct_orientation:
        ax.invert_yaxis()
        ax.xaxis.tick_top()

    # Turn off all the tick marks but keep the labels. Done after the
    # orientation change because tick_top() switches the top marks back on.
    # (The per-tick tick1On/tick2On attributes no longer exist in current
    # Matplotlib, so tick_params is used instead.)
    ax.tick_params(axis='both', which='major',
                   bottom=False, top=False, left=False, right=False)

    # Add color bar
    fig.colorbar(c, ax=ax)

    # Add text in each cell
    show_values(c)

    # Resize (figure size is specified in centimetres)
    fig.set_size_inches(cm2inch(figure_width, figure_height))

def _parse_classification_report(report):
    '''Extract per-class rows from a classification-report string.

    Returns ``(class_names, metrics, support)`` where ``metrics`` holds one
    ``[precision, recall, f1]`` list per class. Only the first contiguous
    group of data rows is read, so summary rows that follow the blank
    separator line (averages, accuracy) are ignored.
    '''
    class_names = []
    metrics = []
    support = []
    for raw_line in report.split('\n'):
        line = raw_line.strip()
        if not line:
            if class_names:
                # Blank line after the class rows: only summary rows follow.
                break
            continue
        # Split from the right so class names containing spaces stay intact.
        fields = line.rsplit(None, 4)
        if len(fields) != 5:
            continue  # e.g. the column-header line
        try:
            values = [float(x) for x in fields[1:4]]
            count = int(fields[4])
        except ValueError:
            continue
        class_names.append(fields[0])
        metrics.append(values)
        support.append(count)
    return class_names, metrics, support


def plot_classification_report(classification_report, title='Classification report ', cmap='RdBu'):
    '''
    Plot scikit-learn classification report.
    Extension based on https://stackoverflow.com/a/31689645/395857

    Parameters:
        classification_report: the report text (a string) to plot.
        title: figure title.
        cmap: matplotlib colormap passed on to ``heatmap``.

    Raises:
        ValueError: if no per-class rows can be found in the report text.
    '''
    class_names, plotMat, support = _parse_classification_report(classification_report)
    if not plotMat:
        raise ValueError('no per-class rows found in classification report')

    print('plotMat: {0}'.format(plotMat))
    print('support: {0}'.format(support))

    xlabel = 'Metrics'
    ylabel = 'Classes'
    xticklabels = ['Precision', 'Recall', 'F1-score']
    yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup in enumerate(support)]
    figure_width = 25
    # Grow the figure with the number of classes so rows stay readable.
    figure_height = len(class_names) + 7
    correct_orientation = False
    heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)

def main():
    '''
    Render the classification reports of both models as heatmaps and save
    each one to a PNG file in the current working directory.
    '''
    sampleClassificationReport = classification_report(y_test, y_count_filter_predict, target_names=news.target_names)
    # The report must be plotted before saving; otherwise savefig writes an
    # empty figure.
    plot_classification_report(sampleClassificationReport)
    plt.savefig('countvector_plot_classif_report.png', dpi=200, format='png', bbox_inches='tight')
    plt.close()  # release the figure so the next plot starts clean

    sampleClassificationReport1 = classification_report(y_test, y_Tfidf_filter_predict, target_names=news.target_names)
    plot_classification_report(sampleClassificationReport1)
    plt.savefig('tfidfvector_plot_classif_report1.png', dpi=200, format='png', bbox_inches='tight')
    plt.close()

if __name__ == "__main__":
    # Plot and save both classification-report heatmaps when run as a script.
    main()
    #cProfile.run('main()') # if you want to do some profiling

print(classification_report(y_test,y_Tfidf_filter_predict,target_names = news.target_names))

Traceback (most recent call last):
File "D:\Python35\demo\count_vectorizer_stop.py", line 189, in <module>
File "D:\Python35\demo\count_vectorizer_stop.py", line 179, in main
File "D:\Python35\demo\count_vectorizer_stop.py", line 174, in plot_classification_report
heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
NameError: name 'np' is not defined

解决方法:导入相应的库 numpy(在文件开头添加 `import numpy as np`)。
运行结果如下:

precision recall f1-score support

alt.atheism 0.85 0.89 0.87 201
comp.graphics 0.62 0.88 0.73 250
comp.os.ms-windows.misc 0.93 0.22 0.36 248
comp.sys.ibm.pc.hardware 0.62 0.88 0.73 240
comp.sys.mac.hardware 0.93 0.85 0.89 242
comp.windows.x 0.82 0.85 0.84 263
misc.forsale 0.90 0.79 0.84 257
rec.autos 0.91 0.91 0.91 238
rec.motorcycles 0.98 0.94 0.96 276
rec.sport.baseball 0.98 0.92 0.95 251
rec.sport.hockey 0.92 0.99 0.95 233
sci.crypt 0.91 0.97 0.93 238
sci.electronics 0.87 0.89 0.88 249
sci.med 0.94 0.95 0.95 245
sci.space 0.91 0.96 0.93 221
soc.religion.christian 0.87 0.94 0.90 232
talk.politics.guns 0.89 0.96 0.93 251
talk.politics.mideast 0.95 0.98 0.97 231
talk.politics.misc 0.84 0.90 0.87 188
talk.religion.misc 0.91 0.53 0.67 158

avg / total 0.88 0.86 0.85 4712
precision    recall  f1-score   support

alt.atheism       0.86      0.81      0.83       201
comp.graphics       0.85      0.81      0.83       250
comp.os.ms-windows.misc       0.84      0.87      0.86       248
comp.sys.ibm.pc.hardware       0.78      0.88      0.83       240
comp.sys.mac.hardware       0.92      0.90      0.91       242
comp.windows.x       0.95      0.88      0.91       263
misc.forsale       0.90      0.80      0.85       257
rec.autos       0.89      0.92      0.90       238
rec.motorcycles       0.98      0.94      0.96       276
rec.sport.baseball       0.97      0.93      0.95       251
rec.sport.hockey       0.88      0.99      0.93       233
sci.crypt       0.85      0.98      0.91       238
sci.electronics       0.93      0.86      0.89       249
sci.med       0.96      0.93      0.95       245
sci.space       0.90      0.97      0.93       221
soc.religion.christian       0.70      0.96      0.81       232
talk.politics.guns       0.84      0.98      0.90       251
talk.politics.mideast       0.92      0.99      0.95       231
talk.politics.misc       0.97      0.74      0.84       188
talk.religion.misc       0.96      0.29      0.45       158

avg / total       0.89      0.88      0.88      4712

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息