
tf.reduce_sum (API r1.3)

2017-11-18 17:03


1. tf.reduce_sum

reduce_sum(
input_tensor,
axis=None,
keep_dims=False,
name=None,
reduction_indices=None
)

Defined in tensorflow/python/ops/math_ops.py.

See the guide: Math > Reduction

Computes the sum of elements across dimensions of a tensor.

Reduces input_tensor along the dimensions given in axis. Unless keep_dims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keep_dims is true, the reduced dimensions are retained with length 1.

If axis is None (the default), all dimensions are reduced, and a tensor with a single element is returned.

For example:

# 'x' is [[1, 1, 1]
#         [1, 1, 1]]
tf.reduce_sum(x) ==> 6
tf.reduce_sum(x, 0) ==> [2, 2, 2]
tf.reduce_sum(x, 1) ==> [3, 3]
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
tf.reduce_sum(x, [0, 1]) ==> 6
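The reduction rules above can be sketched in plain Python on nested lists, which makes the effect of `axis` and `keep_dims` visible without a TensorFlow session. This is a minimal illustration under stated assumptions, not TensorFlow's implementation: `reduce_sum_py` and its helpers are hypothetical names, the input is assumed to be a rectangular, non-empty nested list, negative axes count from the end, and duplicate axes are simply collapsed.

```python
# Hypothetical pure-Python sketch of reduce_sum semantics on nested lists.
# Not TensorFlow code; assumes rectangular, non-empty nested lists.


def _rank(x):
    """Number of nesting levels; a plain number has rank 0."""
    rank = 0
    while isinstance(x, list):
        rank += 1
        x = x[0]
    return rank


def _add(a, b):
    """Elementwise sum of two equally shaped nested lists (or numbers)."""
    if isinstance(a, list):
        return [_add(p, q) for p, q in zip(a, b)]
    return a + b


def _sum_axis(x, axis):
    """Sum x along one axis, removing that dimension."""
    if axis == 0:
        result = x[0]
        for item in x[1:]:
            result = _add(result, item)
        return result
    return [_sum_axis(item, axis - 1) for item in x]


def _expand(x, axis):
    """Insert a length-1 dimension at position axis."""
    if axis == 0:
        return [x]
    return [_expand(item, axis - 1) for item in x]


def reduce_sum_py(x, axis=None, keep_dims=False):
    rank = _rank(x)
    if axis is None:
        axes = list(range(rank))  # None means: reduce every dimension
    elif isinstance(axis, int):
        axes = [axis]
    else:
        axes = list(axis)
    normalized = set()
    for a in axes:
        if not -rank <= a < rank:
            raise ValueError("axis %d out of range for rank %d" % (a, rank))
        normalized.add(a + rank if a < 0 else a)  # negative axes count from the end
    axes = sorted(normalized)  # this sketch silently collapses duplicates

    result = x
    # Reduce the highest axis first so the remaining axis indices stay valid.
    for a in reversed(axes):
        result = _sum_axis(result, a)
    if keep_dims:
        # Re-insert each reduced dimension with length 1, lowest axis first.
        for a in axes:
            result = _expand(result, a)
    return result


x = [[1, 1, 1],
     [1, 1, 1]]
print(reduce_sum_py(x))                     # 6
print(reduce_sum_py(x, 0))                  # [2, 2, 2]
print(reduce_sum_py(x, 1))                  # [3, 3]
print(reduce_sum_py(x, 1, keep_dims=True))  # [[3], [3]]
print(reduce_sum_py(x, [0, 1]))             # 6
```

Reducing the highest axis first keeps the lower axis indices valid while dimensions are removed; the length-1 dimensions are then re-inserted lowest-first, which puts each one back at its original position.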

Args:
input_tensor: The tensor to reduce. Should have numeric type.
axis: The dimensions to reduce. If None (the default), reduces all dimensions.
keep_dims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
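The `axis` and `keep_dims` arguments alone determine the output shape, independent of the values. A small, hypothetical helper `reduced_shape` (not part of TensorFlow) sketches that rule: each listed axis is dropped, or kept with length 1 when `keep_dims` is true. As assumptions of the sketch, it skips the range checks TensorFlow performs and ignores `reduction_indices`.

```python
# Hypothetical helper (not a TensorFlow API): the output shape that the
# rules above predict for reduce_sum on an input of the given shape.
# No range checking is done in this sketch.


def reduced_shape(shape, axis=None, keep_dims=False):
    rank = len(shape)
    if axis is None:
        axes = set(range(rank))  # default: every dimension is reduced
    else:
        axis_list = [axis] if isinstance(axis, int) else list(axis)
        # Negative axes count from the end, e.g. -1 is the last dimension.
        axes = {a + rank if a < 0 else a for a in axis_list}
    if keep_dims:
        # Reduced dimensions stay in place with length 1.
        return tuple(1 if i in axes else d for i, d in enumerate(shape))
    # Otherwise each reduced dimension disappears, lowering the rank by one per axis.
    return tuple(d for i, d in enumerate(shape) if i not in axes)


print(reduced_shape((2, 3)))                     # ()
print(reduced_shape((2, 3), 0))                  # (3,)
print(reduced_shape((2, 3), 1))                  # (2,)
print(reduced_shape((2, 3), 1, keep_dims=True))  # (2, 1)
print(reduced_shape((2, 3), [0, 1]))             # ()
```

For the (2, 3) input used in example 1, these shapes agree with the `output*.shape` lines in that example's output.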

Returns:

The reduced tensor.



numpy compatibility:

Equivalent to np.sum
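The NumPy equivalence can be checked on the same `x` as the example above. A minimal sketch follows; note that NumPy spells the argument `keepdims` rather than `keep_dims`, and that this sketch passes a tuple, the form NumPy documents for reducing several axes at once.

```python
import numpy as np

# Same data as the 'x' example above.
x = np.array([[1, 1, 1],
              [1, 1, 1]])

total = np.sum(x)                                 # 6
col_sums = np.sum(x, axis=0)                      # [2 2 2]
row_sums = np.sum(x, axis=1)                      # [3 3]
row_sums_kept = np.sum(x, axis=1, keepdims=True)  # shape (2, 1)
all_axes = np.sum(x, axis=(0, 1))                 # 6

for name, value in [("total", total), ("col_sums", col_sums),
                    ("row_sums", row_sums), ("row_sums_kept", row_sums_kept),
                    ("all_axes", all_axes)]:
    print(name, np.shape(value))
    print(value)
```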

2. example 1

import tensorflow as tf
import numpy as np

t1 = tf.constant([[0, 1, 2], [3, 4, 5]], dtype=np.float32)

rs0 = tf.reduce_sum(t1)
rs1 = tf.reduce_sum(t1, 0)
rs2 = tf.reduce_sum(t1, 1)
rs3 = tf.reduce_sum(t1, 1, keep_dims=True)
rs4 = tf.reduce_sum(t1, [0, 1])

with tf.Session() as sess:
    input_t1 = sess.run(t1)
    print("input_t1.shape:")
    print(input_t1.shape)
    print("input_t1:")
    print(input_t1)
    print('\n')

    output0 = sess.run(rs0)
    print("output0.shape:")
    print(output0.shape)
    print("output0:")
    print(output0)
    print('\n')

    output1 = sess.run(rs1)
    print("output1.shape:")
    print(output1.shape)
    print("output1:")
    print(output1)
    print('\n')

    output2 = sess.run(rs2)
    print("output2.shape:")
    print(output2.shape)
    print("output2:")
    print(output2)
    print('\n')

    output3 = sess.run(rs3)
    print("output3.shape:")
    print(output3.shape)
    print("output3:")
    print(output3)
    print('\n')

    output4 = sess.run(rs4)
    print("output4.shape:")
    print(output4.shape)
    print("output4:")
    print(output4)

output:

input_t1.shape:
(2, 3)
input_t1:
[[ 0.  1.  2.]
 [ 3.  4.  5.]]

output0.shape:
()
output0:
15.0

output1.shape:
(3,)
output1:
[ 3.  5.  7.]

output2.shape:
(2,)
output2:
[  3.  12.]

output3.shape:
(2, 1)
output3:
[[  3.]
 [ 12.]]

output4.shape:
()
output4:
15.0

Process finished with exit code 0

