
ResNet: Building a Residual Neural Network
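The following TensorFlow (1.x API) script assembles a bottleneck-style residual network: a 7×7 convolution and max-pooling stem, six stages of 1×1 → 3×3 → 1×1 bottleneck residual units with identity shortcuts, global average pooling, and a fully connected output layer. Helper functions for convolution, fully connected layers, and shape printing come first; a short usage sketch follows the model definition.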

# -*- coding: utf-8 -*-
import tensorflow as tf
from collections import namedtuple
from math import sqrt

def print_activations(t):
    # Print a tensor's graph name and static shape, for tracing layer sizes.
    print(t.op.name, t.get_shape().as_list())

def conv2d(x, n_filters, k_h=5, k_w=5,
           stride_h=2, stride_w=2,
           stddev=0.02, activation=lambda x: x,
           bias=True, padding='SAME', name="Conv2D"):
    # 2-D convolution with optional bias; weights/biases are logged to TensorBoard.
    with tf.variable_scope(name):
        w = tf.get_variable('weight', [k_h, k_w, x.get_shape()[-1], n_filters],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        tf.summary.histogram(name + '/weight', w)
        conv = tf.nn.conv2d(x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
        if bias:
            b = tf.get_variable('bias', [n_filters],
                                initializer=tf.truncated_normal_initializer(stddev=stddev))
            tf.summary.histogram(name + '/bias', b)
            conv = conv + b
        print_activations(conv)
        return activation(conv)
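# A minimal usage sketch for conv2d (the names below are illustrative, not part
# of the original post): with the default 5x5 kernel and stride 2, each spatial
# dimension is halved.
#   images = tf.placeholder(tf.float32, [None, 28, 28, 1])
#   feats = conv2d(images, 32, name='demo_conv')   # -> [None, 14, 14, 32]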

def linear(x, n_units, scope=None, stddev=0.02,
           activation=tf.identity):
    # Fully connected layer: activation(x @ weight + bias).
    shape = x.get_shape().as_list()
    with tf.variable_scope(scope or "linear"):
        weight = tf.get_variable("weight", [shape[1], n_units], tf.float32,
                                 tf.random_normal_initializer(stddev=stddev))
        tf.summary.histogram('weight', weight)
        bias = tf.get_variable('bias', [n_units], tf.float32,
                               tf.random_normal_initializer(stddev=stddev))
        tf.summary.histogram('bias', bias)
        return activation(tf.matmul(x, weight) + bias)
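# A minimal usage sketch for linear (illustrative names): it expects a
# 2-D [batch, features] input.
#   flat = tf.placeholder(tf.float32, [None, 4096])
#   logits = linear(flat, 10, scope='demo_fc')     # -> [None, 10]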

def ResNet(x, n_outputs, activation=tf.nn.relu):
    # LayerBlock is a plain data container (no methods): repeats per stage,
    # output channel count, and the width of the 1x1 bottleneck.
    LayerBlock = namedtuple('LayerBlock',
                            ['num_repeats', 'num_filters', 'bottleneck_size'])
    blocks = [LayerBlock(3, 128, 32),
              LayerBlock(3, 256, 64),
              LayerBlock(3, 512, 128),
              LayerBlock(3, 1024, 256),
              LayerBlock(3, 2048, 512),
              LayerBlock(3, 4096, 1024)]

    # If the input is flat (e.g. [batch, 784]), reshape it into a square
    # single-channel image (e.g. [batch, 28, 28, 1]).
    input_shape = x.get_shape().as_list()
    if len(input_shape) == 2:
        ndim = int(sqrt(input_shape[1]))
        if ndim * ndim != input_shape[1]:
            raise ValueError('input_shape should be square')
        x = tf.reshape(x, [-1, ndim, ndim, 1])
    tf.summary.image('input', x, 10)
    # The first convolution expands to 64 channels and downsamples (stride 2).
    net = conv2d(x, 64, k_h=7, k_w=7, name='conv1', activation=activation)
    net = tf.nn.max_pool(net, [1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    print_activations(net)

    # 1x1 convolution up to the first residual block's channel count.
    net = conv2d(net, blocks[0].num_filters, k_h=1, k_w=1,
                 stride_h=1, stride_w=1, padding='VALID', name='conv2')

    # Loop over the residual blocks.
    for blocks_i, block in enumerate(blocks):
        for repeat_i in range(block.num_repeats):
            name = 'block_%d/repeat_%d' % (blocks_i, repeat_i)
            # Bottleneck unit: 1x1 reduce -> 3x3 -> 1x1 restore, then add the shortcut.
            conv = conv2d(net, block.bottleneck_size, k_h=1, k_w=1,
                          padding='VALID', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_in')
            # The 3x3 convolution uses 'SAME' padding so its output keeps the
            # spatial size of the shortcut branch for the addition below.
            conv = conv2d(conv, block.bottleneck_size, k_h=3, k_w=3,
                          padding='SAME', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_bottleneck')
            conv = conv2d(conv, block.num_filters, k_h=1, k_w=1,
                          padding='VALID', stride_h=1, stride_w=1,
                          activation=activation, name=name + '/conv_out')
            net = conv + net  # residual (identity shortcut) connection

        try:
            # If there is a next block, upscale channels to its width.
            next_block = blocks[blocks_i + 1]
            net = conv2d(net, next_block.num_filters, k_h=3, k_w=3,
                         padding='SAME', stride_h=1, stride_w=1,
                         name='block_%d/conv_upscale' % blocks_i)
        except IndexError:
            pass
    # Global average pooling over the remaining spatial dimensions.
    net = tf.nn.avg_pool(net,
                         ksize=[1, net.get_shape().as_list()[1],
                                net.get_shape().as_list()[2], 1],
                         strides=[1, 1, 1, 1], padding='VALID')
    print_activations(net)

    # Flatten to [batch, features] for the fully connected output layer.
    net = tf.reshape(net, [-1, net.get_shape().as_list()[1] *
                           net.get_shape().as_list()[2] *
                           net.get_shape().as_list()[3]])
    print_activations(net)

    net = linear(net, n_outputs)
    return net
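As a sanity check that the graph wires together, here is a minimal training sketch on MNIST-style 784-dimensional inputs. The placeholder shapes, class count, optimizer, and learning rate are illustrative assumptions, not part of the original post.

# Hedged usage sketch (assumes flat 784-d inputs and 10 classes).
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')

logits = ResNet(x, n_outputs=10)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)  # arbitrary learning rate

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Train by feeding batches:
    # sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys})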