您的位置:首页 > 运维架构

tensorflow API: nn.in_top_k 、top_k

2018-01-03 18:21 447 查看
原文连接

对例子解释了一下,直接来吧:

tf.nn.top_k(input, k, name=None)


解释:这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引。

import tensorflow as tf
import numpy as np

input = tf.constant(np.random.rand(3,4))
k = 2
"""输出的每行最大的k个数,还有k个数的索引等信息"""
output = tf.nn.top_k(input, k)
with tf.Session() as sess:
print(sess.run(input))
print(sess.run(output))


输出:

[[ 0.11658417  0.0049587   0.34396945  0.80061182]
[ 0.94435975  0.54798914  0.52284388  0.05966983]
[ 0.44605413  0.06890732  0.67666671  0.05019359]]
TopKV2(values=array([[ 0.80061182,  0.34396945],
[ 0.94435975,  0.54798914],
[ 0.67666671,  0.44605413]]), indices=array([[3, 2],
[0, 1],
[2, 0]]))


tf.nn.in_top_k(predictions, targets, k, name=None)


解释:这个函数的作用是返回一个布尔向量,说明目标值targets是否存在于预测值predictions之中。

输出数据是一个 targets 长度的布尔向量,如果目标值存在于预测值之中,那么 out[i] = true。

注意:targets 是predictions中的索引位,并不是 predictions 中具体的值。

import tensorflow as tf
import numpy as np

input = tf.constant(np.random.rand(3,4), tf.float32)
k = 2
"""给出targets=[1,1,1],里面的每个元素代表各行的索引范围,在这个范围内的数是否在该行的最大k个数里。"""
output = tf.nn.in_top_k(input, [1,1,1], k)
with tf.Session() as sess:
print(sess.run(input))
print(sess.run(output))


输出:

"""第二行中的前两个都在最大的k(2)个里,所以是true"""
[[ 0.96177077  0.53093106  0.05769594  0.63572675]
[ 0.66133296  0.79599202  0.31054178  0.60332161]
[ 0.79283494  0.4200708   0.22985983  0.60160613]]
[False  True False]
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: