您的位置:首页 > 其它

TensorFlow自定义Layer

2018-02-02 16:49 190 查看
import numpy as np
import tensorflow as tf
import os
from tensorflow.python.framework import ops
# Clear the default graph before any ops are added to it.
ops.reset_default_graph()

sess = tf.Session()

# Input: one 4x4 array with a trailing size-1 axis, filled with uniform
# random values (presumably batch/height/width/channels -- conv2d's default
# layout; confirm if data_format is ever changed).
x_shape = [1, 4, 4, 1]
x_val = np.random.uniform(size=x_shape)

x_data = tf.placeholder(dtype=tf.float32, shape=x_shape)

# A 2x2 filter whose four weights are all 0.25 sums to 1, so each output value
# is the mean of one 2x2 window; the stride of 2 moves a full window each step.
my_filter = tf.constant(0.25, shape=[2, 2, 1, 1])
my_strides = [1, 2, 2, 1]
mov_avg_layer = tf.nn.conv2d(
    x_data,
    my_filter,
    strides=my_strides,
    padding='SAME',
    name='Moving_Avg_Window')

# Layer composed of several operations
def custom_layer(input_matrix):
    """Return sigmoid(A * x + b) for the squeezed input.

    Args:
        input_matrix: Tensor to transform. All size-1 dimensions are removed
            with ``tf.squeeze`` first; what remains must be a matrix that the
            2x2 constant ``A`` can left-multiply and that is compatible with
            the 2x2 bias ``b``. In this script it is the output of the
            averaging convolution (presumably shaped [1, 2, 2, 1] -- confirm
            if the input shape or strides change).

    Returns:
        Tensor with the element-wise sigmoid of ``A * x + b``.
    """
    input_matrix_squeezed = tf.squeeze(input_matrix)
    A = tf.constant([[1., 2.], [-1., 3.]])
    b = tf.constant(1., shape=[2, 2])
    temp1 = tf.matmul(A, input_matrix_squeezed)
    temp = tf.add(temp1, b)  # Ax + b
    return tf.sigmoid(temp)

# Build the custom layer's ops inside their own name scope.
with tf.name_scope('Custom_Layer') as scope:
    custom_layer1 = custom_layer(mov_avg_layer)

# Evaluate the layer on the random input generated earlier in the script.
print(sess.run(custom_layer1, feed_dict={x_data: x_val}))

# Sample output from one run. x_val comes from np.random.uniform, so the
# printed values differ from run to run.
'''
[[0.9270415  0.91142696]
[0.86319745 0.9007292 ]]
'''

# NOTE(review): no tf.summary.* ops are created in the visible script, so
# there appears to be nothing for merge_all to merge here -- confirm before
# passing `merged` to sess.run.
merged = tf.summary.merge_all(key='summaries')

# Make sure the log directory exists before the writer opens a file in it.
if not os.path.exists('tensorboard_logs/'):
    os.makedirs('tensorboard_logs/')

# Keep the writer itself in my_writer. Chaining .flush() onto the constructor
# would store flush()'s return value instead of the FileWriter.
my_writer = tf.summary.FileWriter('tensorboard_logs/', sess.graph)
my_writer.flush()


内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: