您的位置:首页 > 其它

Tensorflow lstm实现的小说撰写预测

2017-03-10 14:48 477 查看
最近,在研究深度学习方面的知识,结合Tensorflow,完成了基于lstm的小说预测程序demo。

lstm是改进的RNN,具有长期记忆功能,相对于RNN,增加了多个门来控制输入与输出。原理方面的知识网上很多,在此,我只是将我短暂学习的tensorflow写一个预测小说的demo,如果有错误,还望大家指出。

1、将小说进行分词,去除空格,建立词汇表与id的字典,生成初始输入模型的x与y

def readfile(file_path):

    f = codecs.open(file_path, 'r', 'utf-8')

    alltext = f.read()

    alltext = re.sub(r'\s','', alltext)

    seglist = list(jieba.cut(alltext, cut_all = False))

    return seglist

    

def _build_vocab(filename):

    data = readfile(filename)

    counter = collections.Counter(data)

    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

    words, _ = list(zip(*count_pairs))

    word_to_id = dict(zip(words, range(len(words))))

    id_to_word = dict(zip(range(len(words)),words))

    dataids = []

    for w in data:

        dataids.append(word_to_id[w])

    return word_to_id, id_to_word,dataids

def dataproducer(batch_size, num_steps):

    word_to_id, id_to_word, data = _build_vocab('F:\\ml\\code\\lstm\\1.txt')

    datalen = len(data)

    batchlen = datalen//batch_size

    epcho_size = (batchlen - 1)//num_steps

    data = tf.reshape(data[0: batchlen*batch_size], [batch_size,batchlen])

    i = tf.train.range_input_producer(epcho_size, shuffle=False).dequeue()

    x = tf.slice(data, [0,i*num_steps],[batch_size, num_steps])

    y = tf.slice(data, [0,i*num_steps+1],[batch_size, num_steps])

    x.set_shape([batch_size, num_steps])

    y.set_shape([batch_size, num_steps])

    return x,y,id_to_word

2、建立lstm模型:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias = 0.5)

lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob = keep_prob)

cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell], num_layers)

3、根据训练数据输出误差反向调整模型

with tf.variable_scope("Model", reuse = None, initializer = initializer):#tensorflow主要通过变量空间来实现共享变量

    with tf.variable_scope("r", reuse = None, initializer = initializer):

        softmax_w = tf.get_variable('softmax_w', [size, vocab_size])

        softmax_b = tf.get_variable('softmax_b', [vocab_size])

    with tf.variable_scope("RNN", reuse = None, initializer = initializer):

        for time_step in range(num_steps):

            if time_step > 0: tf.get_variable_scope().reuse_variables()

            (cell_output, state) = cell(inputs[:, time_step, :], state,)

            outputs.append(cell_output)

            

        output = tf.reshape(outputs, [-1,size])

        

        logits = tf.matmul(output, softmax_w) + softmax_b

        loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [tf.reshape(targets,[-1])], [tf.ones([batch_size*num_steps])])

        

        global_step = tf.Variable(0)

        learning_rate = tf.train.exponential_decay(

        10.0, global_step, 5000, 0.1, staircase=True)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate)

        gradients, v = zip(*optimizer.compute_gradients(loss))

        gradients, _ = tf.clip_by_global_norm(gradients, 1.25)

        optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)

4、预测新一轮输出

teststate = test_initial_state

        (celloutput,teststate)= cell(test_inputs, teststate)

        partial_logits = tf.matmul(celloutput, softmax_w) + softmax_b

        partial_logits = tf.nn.softmax(partial_logits)

5、根据之前建立的操作,运行tensorflow会话

sv = tf.train.Supervisor(logdir=None)

with sv.managed_session() as session:

    costs = 0

    iters = 0

    for i in range(1000):

        _,l= session.run([optimizer, cost])

        costs += l

        iters +=num_steps

        perplextity = np.exp(costs / iters)

        if i%20 == 0:

            print(perplextity)

        if i%100 == 0:

            p = random_distribution()

            b = sample(p)

            sentence = id_to_word[b[0]]

            for j in range(200):

                test_output = session.run(partial_logits, feed_dict={test_input:b})

                b = sample(test_output)

                sentence += id_to_word[b[0]]

            print(sentence)    

其中,使用sv.managed_session()后,在此会话间,将不能修改graph。如果采用普通的session,程序将会阻塞于session.run(),对于这个问题,我还是很疑惑,希望理解的人帮忙解答下。

代码地址位于https://github.com/summersunshine1/datamining/tree/master/lstm,运行时只需将readdata中文件路径修改即可。作为深度学习的入门小白,希望大家多多指点。

运行结果如下:

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: