
Training a Simple RNN with TensorFlow

2017-12-21 19:55
I have recently been studying Hands-On Machine Learning with Scikit-Learn & TensorFlow, so the material and code here come from chapter 14 of that book.

What makes this book worthwhile is that it sticks to fundamentals while presenting a complete design workflow, with many references along the way.

I will not repeat the basics of RNNs here; below is a simple worked example.

MNIST Dataset Recognition

First, get familiar with the MNIST data format:

from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("./data/MNIST_data/", one_hot=True)
print(mnist.train.images.shape)
print(mnist.train.labels.shape)
print(mnist.validation.images.shape)
print(mnist.validation.labels.shape)
print(mnist.test.images.shape)
print(mnist.test.labels.shape)
print(mnist.train.labels[1])

x = mnist.train.images[1].reshape((28, 28))
plt.figure()
plt.imshow(x)
plt.show()

mnist1 = input_data.read_data_sets("./data/MNIST_data/")
print(mnist1.train.labels.shape)
print(mnist1.train.labels[1])


Output:

Extracting ./data/MNIST_data/train-images-idx3-ubyte.gz

Extracting ./data/MNIST_data/train-labels-idx1-ubyte.gz

Extracting ./data/MNIST_data/t10k-images-idx3-ubyte.gz

Extracting ./data/MNIST_data/t10k-labels-idx1-ubyte.gz

(55000, 784)

(55000, 10)

(5000, 784)

(5000, 10)

(10000, 784)

(10000, 10)

[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]



Extracting ./data/MNIST_data/train-images-idx3-ubyte.gz

Extracting ./data/MNIST_data/train-labels-idx1-ubyte.gz

Extracting ./data/MNIST_data/t10k-images-idx3-ubyte.gz

Extracting ./data/MNIST_data/t10k-labels-idx1-ubyte.gz

(55000,)

3

With one_hot encoding, the label 3 in mnist.train.labels is converted into the vector [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.].
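As a quick illustration, one-hot encoding can be reproduced with plain NumPy (a minimal sketch of the idea, not how input_data implements it internally):

```python
import numpy as np

def one_hot(label, n_classes=10):
    # Build a zero vector and set the position of the label to 1.
    vec = np.zeros(n_classes, dtype=np.float32)
    vec[label] = 1.0
    return vec

print(one_hot(3))  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```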

RNN Structure

The code uses a hidden layer of 150 recurrent neurons unrolled over 28 time steps (each image is 28 x 28, so each input x is a vector of length 28), followed by a fully connected layer of 10 neurons that produces the output.
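For intuition, a single step of a basic RNN cell computes h_t = tanh(x_t·W_x + h_{t-1}·W_h + b). Below is a minimal NumPy sketch with made-up random weights (the weight matrices here are hypothetical stand-ins; in TensorFlow they are created internally by BasicRNNCell), using the same shapes as the network above:

```python
import numpy as np

n_inputs, n_neurons = 28, 150
rng = np.random.default_rng(0)

# Hypothetical weights for illustration only.
Wx = rng.standard_normal((n_inputs, n_neurons)) * 0.1
Wh = rng.standard_normal((n_neurons, n_neurons)) * 0.1
b = np.zeros(n_neurons)

def rnn_step(x_t, h_prev):
    # One recurrent step: new state from current input and previous state.
    return np.tanh(x_t @ Wx + h_prev @ Wh + b)

h = np.zeros((1, n_neurons))           # initial state for a batch of 1
for t in range(28):                    # 28 time steps = 28 image rows
    x_t = rng.standard_normal((1, n_inputs))
    h = rnn_step(x_t, h)

print(h.shape)  # (1, 150): the final state is what feeds the dense output layer
```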



# Imports
import tensorflow as tf
import numpy as np
from tensorflow.contrib.layers import fully_connected  # fully connected layer
from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()  # reset the graph to avoid variable-reuse errors

# Load the MNIST dataset
mnist = input_data.read_data_sets("./data/MNIST_data/")

n_steps = 28     # number of time steps (one per image row)
n_inputs = 28    # inputs per step (one per pixel in a row)
n_neurons = 150  # recurrent neurons in the hidden layer
n_outputs = 10   # output classes

learning_rate = 0.001

X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])  # 28 steps of 28 inputs each
y = tf.placeholder(tf.int32, [None])

#he_init = tf.contrib.layers.variance_scaling_initializer()  # He initialization
#with tf.variable_scope("rnn", initializer=he_init):
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)

loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()

n_epochs = 100
batch_size = 150

X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))  # reshape the test set into 28x28 sequences
y_test = mnist.test.labels

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        # Note: this runs batch_size (150) mini-batches per epoch;
        # the book uses mnist.train.num_examples // batch_size iterations instead.
        for iteration in range(batch_size):
            x_batch, y_batch = mnist.train.next_batch(batch_size)
            x_batch = x_batch.reshape((-1, n_steps, n_inputs))  # reshape the batch into 28x28 sequences
            sess.run(training_op, feed_dict={X: x_batch, y: y_batch})
        if epoch % 10 == 0:
            acc_train = accuracy.eval(feed_dict={X: x_batch, y: y_batch})
            acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
            print(epoch, "train accuracy:", acc_train, "Test accuracy", acc_test)


Model training results

0 train accuracy: 0.88666666 Test accuracy 0.85790002
10 train accuracy: 0.96666664 Test accuracy 0.96520001
20 train accuracy: 0.97333336 Test accuracy 0.97000003
30 train accuracy: 0.98666668 Test accuracy 0.972
40 train accuracy: 0.99333334 Test accuracy 0.97430003
50 train accuracy: 0.99333334 Test accuracy 0.977
60 train accuracy: 0.96666664 Test accuracy 0.97259998
70 train accuracy: 0.98666668 Test accuracy 0.97689998
80 train accuracy: 0.99333334 Test accuracy 0.97920001
90 train accuracy: 1.0 Test accuracy 0.97960001


Backpropagation Training for RNNs (to be continued)

Reference: Hands-On Machine Learning with Scikit-Learn & TensorFlow