
Getting vectors from BERT

2020-07-02 16:23

modeling and tokenization are the BERT modules from Google's GitHub repository: https://github.com/google-research/bert

chinese_L-12_H-768_A-12 is the model pretrained on Chinese text: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
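The script below is written against the TensorFlow 1.x graph API (tf.placeholder, tf.Session). If only TensorFlow 2 is installed, one minimal workaround is to load the v1 compatibility shim in place of the plain import, roughly like this:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # re-enables graph mode, placeholders and sessions

With that substitution the rest of the code should run unchanged.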

from bert_demo import modeling
from bert_demo import tokenization
import numpy as np
import tensorflow as tf

class bert_vec():
    def __init__(self):
        # placeholders for the input graph
        self.input_ids = tf.placeholder(tf.int32, shape=[None, None], name='input_ids')
        self.input_mask = tf.placeholder(tf.int32, shape=[None, None], name='input_masks')
        self.segment_ids = tf.placeholder(tf.int32, shape=[None, None], name='segment_ids')

        bert_config = modeling.BertConfig.from_json_file('chinese_L-12_H-768_A-12/bert_config.json')
        # initialize BERT
        self.model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=self.input_ids,
            input_mask=self.input_mask,
            token_type_ids=self.segment_ids,
            use_one_hot_embeddings=False
        )
        # path of the pretrained BERT checkpoint
        init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
        # the model's trainable parameters
        tvars = tf.trainable_variables()
        # load the pretrained weights
        (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def get_embedding(self, char_lists, mask_lists, seg_lists):
        # take the last encoder layer (the second-to-last is also available)
        encoder_last_layer = self.model.get_sequence_output()
        # encoder_last2_layer = self.model.all_encoder_layers[-2]

        feed_data = {self.input_ids: np.asarray(char_lists), self.input_mask: np.asarray(mask_lists), self.segment_ids: np.asarray(seg_lists)}
        embedding = self.sess.run(encoder_last_layer, feed_dict=feed_data)
        return embedding

if __name__ == '__main__':

    # prepare the input
    string = '设置一个随机种子'
    char_list = ['[CLS]'] + list(string) + ['[SEP]']
    # no masking: every position is a real token
    mask_list = [1] * (len(string) + 2)
    # a single sentence, so every segment id is 0
    seg_list = [0] * (len(string) + 2)

    # map each character to its id using BERT's vocabulary;
    # out-of-vocabulary tokens raise an error, so either patch the raising code
    # to return '[UNK]' instead, or do the lookup yourself (sketched right below)
    token = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt')
    char_list = token.convert_tokens_to_ids(char_list)
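    # A do-it-yourself lookup, as an alternative to patching tokenization.py:
    # FullTokenizer keeps the vocab.txt mapping in its .vocab dict, so falling
    # back to '[UNK]' for unknown characters is a one-liner (left commented out
    # so it does not redo the conversion above):
    # unk_id = token.vocab['[UNK]']
    # char_list = [token.vocab.get(c, unk_id) for c in ['[CLS]'] + list(string) + ['[SEP]']]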
    bertVec = bert_vec()

    # get the BERT embeddings
    embedding = bertVec.get_embedding([char_list], [mask_list], [seg_list])
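If the checkpoint loads correctly, get_sequence_output() returns the last encoder layer with shape [batch_size, sequence_length, hidden_size]; for this input that is (1, 10, 768): 8 characters plus [CLS] and [SEP], and 768 hidden units for this checkpoint. A quick sanity check, reusing the variables above:

print(embedding.shape)        # expected: (1, 10, 768)
cls_vector = embedding[0][0]  # the 768-dim [CLS] vector, commonly used as a sentence-level feature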