
Learning Bayesian Classification, Part 3: Automatic Naive Bayes News Text Classification with Python + jieba + MongoDB

2017-04-16 11:03

First, a look at the project directory:



The mongo_client directory holds the MongoDB connection and the data-retrieval code.
The navie_bayes directory holds the naive Bayes implementation.
The tags_posterior directory holds the pre-computed posterior data for each sample tag.
The tags_priori directory holds the pre-computed prior probability of each sample tag.
The training directory holds the sample-training code.
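
Laid out as a tree, the project looks roughly like this (a sketch reconstructed from the descriptions above and the file headings below; the original screenshot is not reproduced here):

mongo_client/
    mongodb_client.py
navie_bayes/
    naive_bayes_classifier.py
tags_posterior/
    <one .txt file per tag>
tags_priori/
    priori.txt
training/
    navie_bayes_training.py
    stop_words.txt
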
Next, a brief introduction to the Python implementation in each directory:


mongo_client/mongodb_client.py

# -*- coding: utf-8 -*-
import os

from pymongo import MongoClient


class MongodbClient(object):
    """
    Wraps the MongoDB connection and the queries used for training.
    """
    def __init__(self):
        db_client = MongoClient("localhost", 27017)
        db = db_client.toutiao
        self.collection = db.sample

    def find_all_db_tags(self):
        """
        Collect all distinct tags stored in MongoDB.
        :return: list of tag names
        """
        tags_list = self.collection.distinct('tag')

        if len(tags_list) == 0:
            print 'no tags in db, please check the MongoDB connection'

        return tags_list

    def find_all_dir_tags(self):
        """
        Find the tags that already exist in the tags_posterior directory.
        Entries are file names like 'tag.txt', not bare tag names.
        :return: list of file names
        """
        return os.listdir('tags_posterior/')

    def create_tag_file(self):
        """
        Create one empty file per tag under tags_posterior/.
        :return:
        """
        if not os.path.isdir('tags_posterior'):
            os.makedirs('tags_posterior')

        tags_list = self.find_all_db_tags()

        for tag in tags_list:
            # tag is unicode, but every file name is utf-8 encoded
            filename = 'tags_posterior/' + tag.encode('utf-8') + '.txt'

            if not os.path.exists(filename):
                open(filename, 'wb').close()

    def find_all_articles(self):
        """
        Fetch every document that has a real content field.
        :return: pymongo cursor
        """
        article_list = self.collection.find({'content': {'$exists': True, '$ne': '无'}})
        if article_list.count() == 0:
            print 'no data in database, please check the database'
        else:
            return article_list

    def create_tags_probability(self):
        """
        Write each tag's prior probability to tags_priori/priori.txt.
        """
        if not os.path.isdir('tags_priori'):
            os.makedirs('tags_priori')

        total_count = self.find_all_articles().count()

        tags_list = self.find_all_dir_tags()
        for tag in tags_list:
            # tag is like 'tag.txt', so strip the extension
            tag_name = tag[:-4]
            tag_count = self.collection.find({'tag': tag_name,
                                              'content': {'$exists': True, '$ne': '无'},
                                              }).count()

            # the prior probability of this tag in the database
            probability = float(tag_count) / float(total_count)

            line = tag_name + ' ' + repr(probability) + ' ' + repr(tag_count) + '\n'

            probability_file = open('tags_priori/priori.txt', 'ab')
            probability_file.write(line)

    def find_tags_p_probability(self, tag):
        """
        Look up a tag's prior probability in tags_priori/priori.txt.
        """
        tag_priori_file = open('tags_priori/priori.txt', 'rb')

        for line in tag_priori_file:
            tags_list = line.split(' ')
            if tags_list[0] == tag:
                return float(tags_list[1])
        return 0
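
For reference, a minimal driver sketch (my own addition, not shown in the original post) that builds the tag files and the prior file with this class could look like the following, assuming MongoDB is running locally with the toutiao.sample collection populated:

# -*- coding: utf-8 -*-
# Hypothetical driver script, not part of the original project.
from mongo_client.mongodb_client import MongodbClient

if __name__ == '__main__':
    client = MongodbClient()

    # create one empty posterior file per tag under tags_posterior/
    client.create_tag_file()

    # write 'tag probability count' lines to tags_priori/priori.txt
    client.create_tags_probability()

    # look up a single tag's prior ('科技' is just an illustrative tag name)
    print client.find_tags_p_probability('科技')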


navie_bayes/naive_bayes_classifier.py

# -*- coding: utf-8 -*-
import jieba
import jieba.analyse

from mongo_client.mongodb_client import MongodbClient


class NBClassifier(object):

    def __init__(self):
        self.db_client = MongodbClient()

    def classify_article(self, article):

        extract_keywords = jieba.analyse.extract_tags(article, topK=10)

        tags_list = self.db_client.find_all_dir_tags()

        # naive Bayes score of each tag
        naive_bayes_probability = list()

        # walk through every tag file in tags_posterior/ and score the article
        for tag in tags_list:

            tags_file = open('tags_posterior/' + tag, 'r')

            p = list()

            for line in tags_file:
                # split one sample article's keyword line
                words_list = line.split(' ')

                # keywords of the sample article
                w_list = list()
                # weights of those keywords
                p_list = list()

                for index, word in enumerate(words_list):
                    # the first token is the group id, which is not a keyword
                    if index == 0:
                        continue
                    elif index % 2 == 1:
                        w_list.append(word)
                    else:
                        p_list.append(word.strip('\n'))

                p.append(self.calculate_tag_posterior(extract_keywords, w_list, p_list))

            # score = best posterior among this tag's samples times the tag's prior
            naive_bayes_probability.append(max(p) * self.calculate_tag_priori(tag[:-4]))

        # MAP (maximum a posteriori) decision
        max_probability = max(naive_bayes_probability)
        max_index = naive_bayes_probability.index(max_probability)
        print naive_bayes_probability
        print 'MAP:', max_probability, 'classified as:', tags_list[max_index][:-4]
        print 'original tag:', tags_list[10], naive_bayes_probability[10]

    def calculate_tag_priori(self, tag):
        return self.db_client.find_tags_p_probability(tag)

    def calculate_tag_posterior(self, keywords, w_list, p_list):
        """
        Use a multinomial model to calculate the posterior.
        :param keywords: the keywords of the test article
        :param w_list: the keywords of a sample article
        :param p_list: the weights of the sample article's keywords
        :return: product of the tag's posteriors

        Note: keywords and w_list do not always have the same length.
        """

        p = 1.0
        c = 100   # assume each article has roughly 100 keywords

        for index, word in enumerate(keywords):

            # plain str needs no encoding, unicode does
            if isinstance(word, unicode):
                word = word.encode('utf-8')

            if word in w_list:
                # additive (Laplace-style) smoothing of the multinomial estimate
                p *= float((float(p_list[w_list.index(word)]) * c + 1)) / float(c + 2)
            else:
                p *= float((0.0 * c + 1)) / float(c + 2)
        return p
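
A quick way to try the classifier on a single piece of text (a sketch of my own; the original post does not show the calling code) could look like this:

# -*- coding: utf-8 -*-
# Hypothetical usage example, not part of the original project.
from navie_bayes.naive_bayes_classifier import NBClassifier

if __name__ == '__main__':
    classifier = NBClassifier()

    # any raw news text; this placeholder stands in for a real article
    test_article = u'这里放一篇待分类的新闻正文'.encode('utf-8')

    # prints the per-tag scores and the MAP decision
    classifier.classify_article(test_article)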


training/navie_bayes_training.py

# -*- coding: utf-8 -*-
import os

import jieba
import jieba.analyse

from mongo_client.mongodb_client import MongodbClient


class NBTraining(object):
    def __init__(self):
        self.db_client = MongodbClient()

    def create_tags_list(self):
        self.db_client.create_tag_file()

    def find_all_tags_list(self):
        return self.db_client.find_all_dir_tags()

    def find_all_articles_list(self):
        return self.db_client.find_all_articles()

    def add_stop_word(self):
        """
        Remove meaningless Chinese words from the article.
        The stop word list is user-defined.
        :return:
        """
        jieba.analyse.set_stop_words('stop_words.txt')

    def open_parallel_analyse(self, thread_count):
        """
        Enable jieba's parallel processing.
        :param thread_count: number of worker processes
        :return:
        """
        jieba.enable_parallel(thread_count)

    def tf_idf_analyze_article(self):
        """
        Use the TF-IDF model to extract keywords from every article.
        :return:
        """
        article_list = self.find_all_articles_list()
        tags_list = self.find_all_tags_list()

        for article in article_list:

            # only documents with both a content and a tag field are analysed
            if 'content' in article and 'tag' in article:
                if article['content'] != u'无' and article['content'] != '':

                    article_name = (article['tag'] + '.txt').encode('utf-8')
                    tag_path = 'tags_posterior/' + article_name

                    # skip group ids that have already been written to the tag file
                    group_id_list = self.find_all_tag_group_id(tag_path)
                    if self.exist_group_id(group_id_list, repr(article['group_id'])):
                        continue

                    # the article has not been analysed yet, so extract its keywords
                    # first, register the user-defined Chinese stop words
                    self.add_stop_word()

                    # second, extract the top 10 keywords with their weights
                    if article_name in tags_list:

                        # analyse content = title + content
                        content = (article['title'] + article['content']).encode('utf-8')

                        # run jieba with 4 parallel workers
                        self.open_parallel_analyse(4)

                        extract_keywords = jieba.analyse.extract_tags(content,
                                                                      topK=10,
                                                                      withWeight=True)
                        article_keywords = list()

                        # group id is a long, so convert it to str
                        article_keywords.append(repr(article['group_id']))

                        for keyword in extract_keywords:
                            # word
                            article_keywords.append(keyword[0].encode('utf-8'))
                            # weight
                            article_keywords.append(repr(keyword[1]))

                        article_keywords_line = ' '.join(str(word) for word in article_keywords)
                        tags_file = open(tag_path, 'ab')
                        tags_file.write(article_keywords_line + '\n')

                    else:
                        print 'no tag file for ' + article_name

    def find_all_tag_group_id(self, path):
        """
        Read all group ids already stored in a tag file.
        :param path:
        :return: group id list
        """
        if not os.path.exists(path):
            print 'the path for group ids does not exist'
            return list()

        tags_file = open(path, 'rb')

        group_id_list = list()

        for line in tags_file:
            words_list = line.split(' ')

            if len(words_list) > 0:
                group_id_list.append(words_list[0])

        return group_id_list

    def exist_group_id(self, group_id_list, group_id):
        """
        Check whether a group id was already recorded, so that
        keyword lines are not written twice.
        :param group_id_list:
        :param group_id:
        :return: True or False
        """
        return group_id in group_id_list

    def clear_all_tags_file(self):
        """
        Manually remove all keyword lines from the tags_posterior files.
        :return:
        """
        tags_list = self.find_all_tags_list()
        for tag in tags_list:
            tag_file = open('tags_posterior/' + tag, 'wb')
            tag_file.truncate()
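
Putting the pieces together, a full training run (again a sketch of my own; the original post leaves the driver code out) could be as short as:

# -*- coding: utf-8 -*-
# Hypothetical training driver, not part of the original project.
from mongo_client.mongodb_client import MongodbClient
from training.navie_bayes_training import NBTraining

if __name__ == '__main__':
    trainer = NBTraining()

    # 1. create one posterior file per tag under tags_posterior/
    trainer.create_tags_list()

    # 2. extract each article's TF-IDF keywords into its tag file
    trainer.tf_idf_analyze_article()

    # 3. compute the tag priors and write them to tags_priori/priori.txt
    MongodbClient().create_tags_probability()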


training/stop_words.txt

Stop words: words that carry little meaning and do not need to be counted.
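
jieba's set_stop_words expects one word per line; a few illustrative entries (not taken from the original file) would be:

的
了
我们
这个
可以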



tags_posterior/
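
Each file in this directory holds one line per training article: the group id followed by keyword/weight pairs, exactly as written by tf_idf_analyze_article above. An illustrative line (made-up values, truncated to three pairs):

6403969791 人工智能 1.2357 芯片 0.9841 手机 0.7763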



tags_priori/priori.txt
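
Each line of priori.txt holds a tag name, its prior probability, and its article count, in the format produced by create_tags_probability above. Illustrative lines (made-up values):

科技 0.1532 187
娱乐 0.2210 270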
