
Training Your Own Images with TensorFlow, Part 2: Building the Network (AlexNet)

Once the data is ready, the next step is building the network. I define the model in its own file here, which makes it easier to revise the network later.
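Because the model lives in its own file (DR_alexnet.py, listed below), the training script only needs to import and instantiate the class. Here is a minimal sketch of that usage; the placeholder shape follows the 512*512*3 input used in buildCNN, and the keep probability and class count are only example values:

import tensorflow as tf
from DR_alexnet import alexNet

x = tf.placeholder(tf.float32, [None, 512, 512, 3])   # input image batch
keep_prob = tf.placeholder(tf.float32)                 # dropout keep probability
model = alexNet(x, keep_prob, classNum=2, skip=[])     # classNum=2 is only an example
logits = model.fc3                                      # fc8 output, fed to the loss later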

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""
Spyder Editor
This is a temporary script file.
filename:DR_alexnet.py
created: 2018-01-16
author:huangxudong
"""
import tensorflow as tf
import numpy as np
# define the different layer functions
def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding="SAME"):
    return tf.nn.max_pool(x, ksize=[1, kHeight, kWidth, 1],
                          strides=[1, strideX, strideY, 1],
                          padding=padding, name=name)

def dropout(x, keepPro, name=None):
    return tf.nn.dropout(x, keepPro, name=name)

def LRN(x, R, alpha, beta, name=None, bias=1.0):  # local response normalization
    return tf.nn.local_response_normalization(x, depth_radius=R, alpha=alpha,
                                              beta=beta, bias=bias, name=name)
def fcLayer(x, inputD, outputD, reluFlag, name):
    """fully connected layer"""
    with tf.variable_scope(name) as scope:
        w = tf.get_variable("w", shape=[inputD, outputD])  # shape gives the variable's dimensions
        b = tf.get_variable("b", [outputD])
        out = tf.nn.xw_plus_b(x, w, b, name=scope.name)
        if reluFlag:
            return tf.nn.relu(out)
        else:
            return out
def convLayer(x, kHeight, kWidth, strideX, strideY, featureNum, name, padding="SAME", groups=1):
    """convolution"""
    channel = int(x.get_shape()[-1])  # number of input channels (last dimension of x)
    conv = lambda a, b: tf.nn.conv2d(a, b, strides=[1, strideY, strideX, 1], padding=padding)  # anonymous helper for a single conv2d
    with tf.variable_scope(name) as scope:
        w = tf.get_variable("w", shape=[kHeight, kWidth, channel // groups, featureNum])
        b = tf.get_variable("b", shape=[featureNum])
        # grouped convolution as in the original AlexNet: split the input along its channel
        # axis and the kernels along their output axis, convolve each pair, then concatenate
        xNew = tf.split(value=x, num_or_size_splits=groups, axis=3)
        wNew = tf.split(value=w, num_or_size_splits=groups, axis=3)
        featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)]
        mergeFeatureMap = tf.concat(axis=3, values=featureMap)
        out = tf.nn.bias_add(mergeFeatureMap, b)
        # print(mergeFeatureMap.get_shape().as_list(), out.shape)
        return tf.nn.relu(out, name=scope.name)  # convolution and activation in one step; out has the same shape as mergeFeatureMap, so no reshape is needed

class alexNet(object):
    """alexNet model"""
    def __init__(self, x, keepPro, classNum, skip, modelPath="bvlc_alexnet.npy"):
        self.X = x
        self.KEEPPRO = keepPro      # dropout keep probability
        self.CLASSNUM = classNum    # number of output classes
        self.SKIP = skip            # layers whose pretrained weights should not be loaded
        self.MODELPATH = modelPath  # path to the pretrained weights file
        # build CNN
        self.buildCNN()

    def buildCNN(self):  # the core: build the model (2800*2100)
        x1 = tf.reshape(self.X, shape=[-1, 512, 512, 3])
        # print(x1.shape)
        conv1 = convLayer(x1, 7, 7, 3, 3, 256, "conv1", "VALID")      # 169*169
        lrn1 = LRN(conv1, 2, 2e-05, 0.75, "norm1")
        pool1 = maxPoolLayer(lrn1, 3, 3, 2, 2, "pool1", "VALID")      # 84*84

        conv2 = convLayer(pool1, 3, 3, 1, 1, 512, "conv2", "VALID")   # 82*82
        lrn2 = LRN(conv2, 2, 2e-05, 0.75, "norm2")
        pool2 = maxPoolLayer(lrn2, 3, 3, 2, 2, "pool2", "VALID")      # 40*40

        conv3 = convLayer(pool2, 3, 3, 1, 1, 1024, "conv3", "VALID")  # 38*38
        conv4 = convLayer(conv3, 3, 3, 1, 1, 1024, "conv4", "VALID")  # 36*36

        conv5 = convLayer(conv4, 3, 3, 2, 2, 512, "conv5", "VALID")   # 17*17
        pool5 = maxPoolLayer(conv5, 3, 3, 2, 2, "pool5", "VALID")     # 8*8
        # print(pool5.shape)
        fcIn = tf.reshape(pool5, [-1, 512 * 8 * 8])
        fc1 = fcLayer(fcIn, 512 * 8 * 8, 4096, True, "fc6")
        dropout1 = dropout(fc1, self.KEEPPRO)

        fc2 = fcLayer(dropout1, 4096, 4096, True, "fc7")
        dropout2 = dropout(fc2, self.KEEPPRO)

        self.fc3 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")

That is the whole network construction. Once the network is built, the pretrained model weights still need to be loaded:
    def loadModel(self, sess):
        """load model"""
        wDict = np.load(self.MODELPATH, encoding="bytes").item()
        for name in wDict:
            if name not in self.SKIP:
                with tf.variable_scope(name, reuse=True):
                    for p in wDict[name]:
                        if len(p.shape) == 1:
                            # bias
                            sess.run(tf.get_variable('b', trainable=False).assign(p))
                        else:
                            # weights
                            sess.run(tf.get_variable('w', trainable=False).assign(p))
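To actually call loadModel, the variables must first be created and initialized inside a session, and the pretrained values then overwrite them. Below is a minimal sketch of that usage; the placeholder shape, keep probability, class count, and skip list are all example values. Note that bvlc_alexnet.npy only fits layers whose shapes match the original AlexNet, so with this modified architecture any layer with a different kernel size, channel count, or class count has to be listed in skip or trained from scratch:

import tensorflow as tf
from DR_alexnet import alexNet

x = tf.placeholder(tf.float32, [None, 512, 512, 3])
# example skip list: layers whose shapes differ from the original pretrained weights
skip = ["conv1", "conv2", "conv3", "conv4", "conv5", "fc6", "fc8"]
model = alexNet(x, 0.5, 2, skip)   # keepPro=0.5 and classNum=2 are example values

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # create and initialize all variables
    model.loadModel(sess)                        # restore the layers not listed in skip
    # ... training or inference code goes here ...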