您的位置:首页 > 编程语言 > Python开发

LR识别垃圾短信,详解python函数

2017-11-01 22:51 399 查看
import array
import collections
import itertools
import operator
import sys

import jieba
import sklearn
import sklearn.linear_model as linear_model
import sklearn.metrics
import sklearn.model_selection

# Split the raw labeled corpus into a training set and a test set.
def fetch_train_test(data_path, test_size=0.2):
    """Read a tab-separated labeled corpus, tokenize it, and split it.

    Each non-empty line of ``data_path`` is parsed as ``<label>\\t<text>``:
    the part before the first tab is converted with ``int()``, the rest is
    segmented with ``jieba.cut``.

    Args:
        data_path: path to a UTF-8 encoded text file.
        test_size: passed straight through to
            ``sklearn.model_selection.train_test_split``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` exactly as returned by
        ``train_test_split``; each X item is a list of tokens and each y
        item is an int label. ``random_state`` is fixed so the split is
        reproducible.
    """
    y = []
    text_list = []
    # `with` closes the file even if parsing a line raises; iterating the
    # file object streams lines instead of materializing readlines().
    with open(data_path, 'r', encoding='utf8') as f:
        for line in f:
            # rstrip('\n') rather than line[:-1]: a last line without a
            # trailing newline would otherwise lose its final character.
            line = line.rstrip('\n')
            if not line:
                # Blank lines (e.g. a trailing empty line) carry no sample
                # and would make the tab split below fail to unpack.
                continue
            label, text = line.split('\t', 1)
            y.append(int(label))
            text_list.append(list(jieba.cut(text)))
    return sklearn.model_selection.train_test_split(
        text_list, y, test_size=test_size, random_state=1028)

# Build the vocabulary (token -> integer id).
def build_dict(text_list, min_freq=5):
    """Assign a dense integer id to every sufficiently frequent token.

    Tokens are counted over all sequences in ``text_list``. Those occurring
    at least ``min_freq`` times are numbered 0, 1, 2, ... in descending order
    of frequency; ties keep first-seen order.

    Args:
        text_list: iterable of token sequences (e.g. lists of words).
        min_freq: minimum total number of occurrences for a token to be kept.

    Returns:
        dict mapping token -> id. It is empty when no token reaches
        ``min_freq`` (the previous zip/unpack form raised ValueError here).
    """
    freq_dict = collections.Counter(itertools.chain.from_iterable(text_list))
    # most_common() with no argument is a stable sort by count, descending:
    # the same order as sorted(items, key=itemgetter(1), reverse=True).
    words = [word for word, count in freq_dict.most_common() if count >= min_freq]
    return {word: idx for idx, word in enumerate(words)}

# Feature extraction: turn token sequences into 0/1 presence vectors.
def text2vect(text_list, word2id):
    """Encode every token sequence as a binary bag-of-words vector.

    Args:
        text_list: iterable of token sequences.
        word2id: mapping token -> column index. Indices are assumed to lie
            in ``range(len(word2id))``, as produced by ``build_dict``.

    Returns:
        list with one ``array.array('l')`` of length ``len(word2id)`` per
        input sequence. Column ``word2id[tok]`` is 1 when ``tok`` occurs in
        the sequence and 0 otherwise. Tokens absent from ``word2id`` are
        ignored.
    """
    width = len(word2id)

    def _encode(tokens):
        # Start from an all-zero row and flag the column of each known token.
        row = array.array('l', [0]) * width
        for token in tokens:
            if token in word2id:
                row[word2id[token]] = 1
        return row

    return [_encode(tokens) for tokens in text_list]

# Model evaluation: accuracy and ROC AUC.
def evaluate(model, X, y):
    """Return ``(accuracy, auc)`` of a fitted classifier on ``(X, y)``.

    Args:
        model: fitted estimator exposing ``score`` and ``predict_proba``
            (e.g. ``sklearn.linear_model.LogisticRegression``).
        X: feature rows accepted by ``model``.
        y: true labels; label ``1`` is treated as the positive class.

    Returns:
        Tuple ``(accuracy, auc)``: the value of ``model.score(X, y)`` and the
        area under the ROC curve computed from ``predict_proba``.
    """
    accuracy = model.score(X, y)
    # Score the `model` argument, not the module-level `lr` global the
    # original referenced, which broke any caller outside __main__.
    # NOTE(review): column 1 of predict_proba is P(model.classes_[1]); this
    # matches pos_label=1 only when the labels are {0, 1} -- confirm for
    # other label sets.
    scores = model.predict_proba(X)[:, 1]
    fpr, tpr, _thresholds = sklearn.metrics.roc_curve(y, scores, pos_label=1)
    return accuracy, sklearn.metrics.auc(fpr, tpr)

if __name__ == '__main__':
    # The corpus path may be given as the first command-line argument;
    # without one, fall back to the original hard-coded location.
    data_path = sys.argv[1] if len(sys.argv) > 1 else 'F:\\train.txt'
    X_train, X_test, y_train, y_test = fetch_train_test(data_path)
    # The vocabulary is built from the training split only, so the test
    # split does not influence which features exist.
    word2id = build_dict(X_train, min_freq=10)
    X_train = text2vect(X_train, word2id)
    X_test = text2vect(X_test, word2id)
    lr = linear_model.LogisticRegression(C=1)
    lr.fit(X_train, y_train)

    accuracy, auc = evaluate(lr, X_train, y_train)
    sys.stdout.write('训练集正确率:%.4f%%\n' % (accuracy * 100))
    sys.stdout.write("训练集AUC值:%.6f\n" % (auc))

    accuracy, auc = evaluate(lr, X_test, y_test)
    sys.stdout.write("测试集正确率:%.4f%%\n" % (accuracy * 100))
    # Label fixed to "测试集" to match the other three output lines.
    sys.stdout.write("测试集AUC值:%.6f\n" % (auc))

build_dict()函数对刚接触Python的同学可能不太友好。我们先来看itertools.chain(*text_list)做了什么:

>>> a=[['a','b','c'],['b','c','d']]
>>> import collections
>>> import itertools
>>> itertools.chain(*a)
<itertools.chain object at 0x00000000035E4BA8>
>>> list(itertools.chain(*a))
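['a', 'b', 'c', 'b', 'c', 'd']

可以看到,itertools.chain()把传入的多个子列表首尾相连,依次产出每个子列表中的元素,相当于把二维列表"拉平"成一维。它返回的是一个惰性迭代器,需要用list()才能看到全部元素。a前面的`*`必不可少:`*`会把a解包成多个独立的参数,即chain(*a)等价于chain(['a','b','c'], ['b','c','d'])。如果不加`*`,chain(a)只收到一个参数,产出的是a中的子列表本身,而不是子列表里的元素。(也可以写成itertools.chain.from_iterable(a),效果相同。)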
>>> freq_dict=collections.Counter(list(itertools.chain(*a)))
>>> freq_dict
Counter({'b': 2, 'c': 2, 'a': 1, 'd': 1})

collections.Counter()统计传入的可迭代对象中每个元素出现的次数,返回一个Counter对象(dict的子类,键为元素,值为出现次数)。注意这里不必先对itertools.chain(*a)调用list(),直接把迭代器传给Counter,结果完全相同。
>>> c=filter(lambda wc:wc[1]>1,freq_dict.items())
>>> c
<filter object at 0x00000000025F5320>
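filter()筛选出出现次数大于1的项。注意要传入freq_dict.items()而不是freq_dict本身。直接遍历字典只会得到键(单词),lambda中的wc[1]就变成了单词的第二个字符,与整数比较会抛出TypeError(单字符的键则直接抛出IndexError)。而.items()产出的是(单词, 次数)二元组,wc[1]才是次数。filter返回的同样是一个惰性迭代器。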
>>> freq_dict.items()
dict_items([('a', 1), ('b', 2), ('c', 2), ('d', 1)])
可以看到,.items()返回的是由(键, 值)二元组组成的视图。
>>> c=filter(lambda wc:wc[1]>1,freq_dict.items())
>>> words,_ = zip(*c)
>>> words
('b', 'c')
>>> _
(2, 2)
同样,这里的`*`也必不可少。zip(*c)把c中的每个(单词, 次数)二元组作为独立参数传给zip,相当于zip(('b', 2), ('c', 2)),从而把二元组"转置"成(所有单词)和(所有次数)两个元组。若不加`*`,则会出现下面的结果:
>>> c=filter(lambda wc:wc[1]>1,freq_dict.items())
>>> words,_ = zip(c)
>>> words
(('b', 2),)
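这是因为zip(c)只收到一个可迭代对象,它会把c中的每个元素包装成一个一元组,得到(('b', 2),)和(('c', 2),)。这里筛选结果恰好有两个元素,所以解包words, _ = ...没有报错。但words是(('b', 2),),_是(('c', 2),),并不是我们想要的结果。如果筛选结果不是恰好两个元素,解包就会抛出ValueError。另外,filter返回的迭代器只能遍历一次,所以上面每次演示前都重新创建了c。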
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息