您的位置:首页 > 其它

训练中文分词HMM模型,得到A(状态转移矩阵)、B(混淆矩阵)、Pi(初始状态概率)

2016-06-08 08:31 1041 查看
#!F://python
# page coding=utf-8
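# A  (state-transition matrix): probability of moving from one hidden state to another.
# B  (confusion / emission matrix): probability of emitting character x given hidden state y.
# Pi (initial-state vector): probability that the initial state is s.
# This script trains an HMM for Chinese word segmentation: it estimates the A matrix,
# the B matrix and the Pi vector, and uses pickle to write them to files
# (characters are stored under their UTF-8 encoding).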
import pickle
import codecs
# Hidden-state tags: Begin / Middle / End of a multi-character word, and
# Single-character word.
state_list = list('BMES')
state_M = len(state_list)
word_N = 0

# Model tables; they hold raw counts while training and probabilities after
# normalisation.
A_dic = dict()   # state-transition matrix: A_dic[state][next_state]
B_dic = dict()   # confusion (emission) matrix: B_dic[state][character]
Pi_dic = dict()  # initial-state vector: Pi_dic[state]

# Normalisation denominators.
Pi_dic_size = float(0)       # divisor applied to every entry of Pi_dic
A_row_count_dic = dict()     # per state: sum of that state's row in A_dic
B_dic_element_size = dict()  # per character: counter kept alongside B_dic

# File names. PROB_SATRT is a misspelling of "START"; the name is kept
# because it is a module-level identifier.
INPUT_DATA = "RenMinData.txt"  # training corpus
PROB_SATRT = "prob_start.py"   # intended for the initial-state vector
PROB_EMIT = "prob_emit.py"     # intended for the confusion matrix
PROB_TRANS = "prob_trans.py"   # intended for the state-transition matrix

def init():
    """Zero the per-state tables for every tag in ``state_list``.

    After the call, ``A_dic[s][t]`` is 0.0 for every pair of tags,
    ``Pi_dic[s]`` is 0.0, ``B_dic[s]`` is an empty dict and
    ``A_row_count_dic[s]`` is 0. The tables are mutated in place;
    ``Pi_dic_size`` and ``B_dic_element_size`` are not touched.
    """
    for tag in state_list:
        # One fresh row per source tag, one zeroed column per target tag.
        A_dic[tag] = dict.fromkeys(state_list, 0.0)
        Pi_dic[tag] = 0.0
        B_dic[tag] = {}
        A_row_count_dic[tag] = 0

def getList(input_str):
    """Return the list of B/M/E/S tags for one word, one tag per character.

    A one-character word maps to ``['S']``. Any other length maps to ``'B'``,
    then ``len(input_str) - 2`` copies of ``'M'``, then ``'E'``; for example a
    two-character word gives ``['B', 'E']`` and a five-character word gives
    ``['B', 'M', 'M', 'M', 'E']``.

    An empty input also yields ``['B', 'E']``: the middle count is negative
    and a negative list repeat produces no elements.
    """
    length = len(input_str)
    if length == 1:
        return ['S']
    middle_tags = ['M'] * (length - 2)
    return ['B'] + middle_tags + ['E']

def main(train_file_path):
    """Collect HMM counts from a segmented corpus, then call ``probs()``.

    The file is read as UTF-8 text, one sentence per line, with words
    separated by whitespace. Each word is tagged with ``getList()`` and the
    tags are accumulated into the module-level tables:

    * ``Pi_dic`` / ``Pi_dic_size``: first tag of every word, and the number
      of words counted, so the normalised vector sums to 1.
    * ``B_dic`` / ``B_dic_element_size``: (tag, character) counts and
      per-character totals, both keyed by the character's UTF-8 bytes.
    * ``A_dic`` / ``A_row_count_dic``: counts of consecutive tag pairs over
      the whole line (so transitions out of ``E`` and ``S`` into the next
      word are included), and the per-row totals.

    Args:
        train_file_path: Path of the segmented training corpus.
    """
    global Pi_dic_size
    init()
    # "utf-8-sig" removes a leading BOM only when one is actually present,
    # so the first real character is never discarded.
    with codecs.open(train_file_path, "r", "utf-8-sig") as train_file:
        for line in train_file:
            # split() without arguments drops the trailing newline and the
            # empty tokens produced by repeated separators.
            word_list = line.split()
            if not word_list:
                continue
            line_state = [getList(word) for word in word_list]

            for word, states in zip(word_list, line_state):
                Pi_dic[states[0]] += 1
                Pi_dic_size += 1
                for char, state in zip(word, states):
                    # Test and store with the same key so that repeated
                    # characters are incremented rather than reset.
                    key = char.encode('utf-8')
                    B_dic[state][key] = B_dic[state].get(key, 0.0) + 1
                    B_dic_element_size[key] = B_dic_element_size.get(key, 0) + 1

            # Transitions are taken over the flattened tag sequence of the
            # line, which covers both within-word and word-to-word steps.
            flat_states = [tag for states in line_state for tag in states]
            for current, following in zip(flat_states, flat_states[1:]):
                A_dic[current][following] += 1
                A_row_count_dic[current] += 1
    print(B_dic_element_size)
    probs()

def probs():
    """Normalise the count tables in place, print them and pickle them.

    * ``Pi_dic`` entries are divided by ``Pi_dic_size``.
    * Each row of ``A_dic`` is divided by its total in ``A_row_count_dic``.
    * Each row of ``B_dic`` is divided by the total count of its state, so
      an entry is the probability of the character given the state and the
      row sums to 1.

    Tables or rows whose denominator is zero are left unchanged instead of
    raising ``ZeroDivisionError``.

    The results are pickled to the module-level paths: ``Pi_dic`` to
    ``PROB_SATRT``, ``B_dic`` to ``PROB_EMIT`` and ``A_dic`` to
    ``PROB_TRANS``.
    """
    print("-------------------以下Pi向量------------------------")
    if Pi_dic_size:
        for state in Pi_dic:
            Pi_dic[state] = Pi_dic[state] / Pi_dic_size
    print(Pi_dic)

    print("-------------------以下是状态转移矩阵------------------------")
    for state in A_dic:
        row_total = A_row_count_dic[state]
        if row_total != 0:
            for next_state in A_dic[state]:
                A_dic[state][next_state] = A_dic[state][next_state] / row_total
    print(A_dic)

    print("------------------以下是混淆矩阵-----------------------")
    for state in B_dic:
        # The emission probability is conditioned on the state, so the
        # denominator is the state's total, not the character's total.
        state_total = sum(B_dic[state].values())
        if state_total:
            for char in B_dic[state]:
                B_dic[state][char] = B_dic[state][char] / state_total
    # All entries are printed on a single line, as a flat listing.
    entries = []
    for state in B_dic:
        for char in B_dic[state]:
            entries.append("%s --> %s %s    " % (state, char, B_dic[state][char]))
    print(" ".join(entries))

    # pickle produces bytes, so every file is opened in binary mode, and
    # each table goes to the file named after it.
    outputs = (
        (PROB_SATRT, Pi_dic),
        (PROB_EMIT, B_dic),
        (PROB_TRANS, A_dic),
    )
    for path, table in outputs:
        with open(path, 'wb') as out_fp:
            pickle.dump(table, out_fp)
# Module-level entry point: training starts as soon as this file is executed
# or imported, because the call is not wrapped in an ``if __name__`` guard.
main("RenMinData.txt")
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  HMM 中文分词