您的位置:首页 > 其它

tf.nn.dynamic_rnn()实现的一个例子。

2018-01-28 13:27 281 查看
tf.nn.dynamic_rnn()在处理变长输入时特别方便,具体解释可以看这篇博文

#coding=utf-8
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)

# 第二个example长度为6
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.BasicLSTMCell(num_units=5, state_is_tuple=True)

outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)

result = tf.contrib.learn.run_n(
{"outputs": outputs, "last_states": last_states},
n=1,
feed_dict=None)
a = result[0]
print(a)

assert result[0]["outputs"].shape == (2, 10, 5)

# 第二个example中的outputs超过6步(7-10步)的值应该为0
assert (result[0]["outputs"][1,7,:] == np.zeros(cell.output_size)).all()


输出:

[[ 0.13337141  0.10697078  0.11238842 -0.16187296  0.04447445]
[ 0.06800554  0.29581101  0.14454009 -0.11857419 -0.08062822]
[-0.02766501  0.20230338  0.25521379 -0.10196185  0.02908   ]
[ 0.07160553  0.39891538  0.03997988 -0.43861938 -0.00340179]
[-0.12841535  0.35346241  0.08577594 -0.29574161 -0.06306395]
[-0.19022677  0.11256105 -0.13190501 -0.20170257 -0.02765217]
[-0.04303006 -0.42253068 -0.02945417 -0.0817529   0.03569792]
[-0.01433148  0.00066725 -0.08619441 -0.1063433   0.36421112]
[ 0.19718385  0.06653057  0.02880462 -0.31631752  0.04064322]
[ 0.07665874  0.15330013  0.11820727 -0.28386946 -0.06841132]]
-------------------------------------------------------
[[ 0.09817442  0.12635493  0.14153314 -0.13827174 -0.14350587]
[ 0.09484242  0.05155221  0.11429032 -0.04175748 -0.11621833]
[ 0.21802519  0.17491722  0.17653461 -0.2161642  -0.17876485]
[ 0.05461165 -0.01181785  0.31818148 -0.18725258 -0.06083239]
[ 0.03753194  0.04578742  0.30538616 -0.09413831 -0.41238963]
[-0.04687686  0.01701181  0.21276684 -0.02761401 -0.07971509]
[ 0.          0.          0.          0.          0.        ]
[ 0.          0.          0.          0.          0.        ]
[ 0.          0.          0.          0.          0.        ]
[ 0.          0.          0.          0.          0.        ]]


参考资料:

tensorflow高阶教程:tf.dynamic_rnn
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  tensorflow RNN