您的位置:首页 > 其它

如何编写训练和测试用的 prototxt 配置文件——以 ResNet 为例

2016-12-12 21:27 501 查看
# Module setup.
# NOTE(review): the working directory is switched to a hard-coded, user-specific
# path *before* `import init_path`, presumably so init_path.py (which likely puts
# pycaffe on sys.path) can be found there. Importing from the cwd only works when
# '' is on sys.path (interactive sessions); `python script.py` puts the script's
# own directory in sys.path[0] instead. Confirm, and consider making this path
# configurable.
import os
os.chdir('/home/wuwl/ResNet')
import init_path
import caffe
import numpy as np  # not used in the code visible here
import tools  # not used in the code visible here
from caffe import layers as L,params as P,to_proto

# Absolute path of the current working directory, i.e. '/home/wuwl/ResNet' after
# the chdir above. All data and model paths in this script are built from it.
this_dir = os.path.abspath(".")
def ResNet(split, repeat=3):
    """Build the CIFAR-10 residual network for one phase and return its proto.

    Layout:
    - a 3x3, 16-filter Conv/BN/Scale/ReLU stem;
    - three stages of ``repeat`` residual blocks with 16, 32 and 64 filters;
    - global average pooling and a 10-way InnerProduct;
    - Accuracy and SoftmaxWithLoss outputs.

    The first block of the second and third stage uses projection_stride=2,
    i.e. a strided main path plus a 1x1 projection shortcut. Every other
    block uses projection_stride=1, i.e. an identity shortcut. With the
    default ``repeat=3`` the main path has 1 + 3*3*2 + 1 = 20 weighted
    layers.

    Args:
        split: 'train' reads the training LMDB with mirroring enabled and
            builds BatchNorm with use_global_stats=False. Any other value
            reads the test LMDB without mirroring and uses
            use_global_stats=True.
        repeat: residual blocks per stage. The default of 3 reproduces the
            original network exactly.

    Returns:
        The value of caffe's ``to_proto(acc, loss)``; ``str()`` of it is the
        prototxt text.
    """
    data_dir = os.path.join(this_dir, 'caffe-master', 'examples', 'cifar10')
    train_data_file = os.path.join(data_dir, 'cifar10_train_lmdb')
    test_data_file = os.path.join(data_dir, 'cifar10_test_lmdb')
    mean_file = os.path.join(data_dir, 'mean.binaryproto')

    is_train = split == 'train'
    transform_param = dict(mean_file=mean_file, crop_size=28)
    if is_train:
        # Random mirroring (a flip, not a rotation) is a training-only
        # augmentation.
        transform_param['mirror'] = True
    data, labels = L.Data(source=train_data_file if is_train else test_data_file,
                          backend=P.Data.LMDB,
                          batch_size=128,
                          ntop=2,  # two tops: images and labels
                          transform_param=transform_param)

    # Stem; only its post-ReLU output feeds the residual stages.
    _, result = conv_BN_scale_relu(split, data, nout=16, ks=3, stride=1, pad=1)
    for stage, nout in enumerate((16, 32, 64)):
        for i in range(repeat):
            # Downsample (and project the shortcut) only when entering a wider
            # stage; the first stage keeps the stem's 16 channels, so its
            # blocks can use identity shortcuts.
            projection_stride = 2 if stage > 0 and i == 0 else 1
            result = ResNet_block(split, result, nout=nout, ks=3, stride=1,
                                  projection_stride=projection_stride, pad=1)

    pool = L.Pooling(result, pool=P.Pooling.AVE, global_pooling=True)
    ip = L.InnerProduct(pool, num_output=10,
                        weight_filler=dict(type='xavier'),
                        bias_filler=dict(type='constant'))
    acc = L.Accuracy(ip, labels)
    loss = L.SoftmaxWithLoss(ip, labels)
    return to_proto(acc, loss)

def conv_BN_scale_relu(split, bottom, nout, ks, stride, pad):
    """Append Convolution -> BatchNorm -> Scale -> ReLU after ``bottom``.

    BatchNorm, Scale and ReLU are all created with in_place=True on the
    convolution's output.

    Args:
        split: 'train' builds BatchNorm with use_global_stats=False; any
            other value sets use_global_stats=True.
        bottom: input top.
        nout: number of convolution filters (num_output).
        ks: convolution kernel size.
        stride: convolution stride.
        pad: convolution padding.

    Returns:
        A tuple ``(scale, relu)`` holding the Scale top and the ReLU top.

    NOTE(review): because every follow-up layer is in place on the same blob,
    ``scale`` carries pre-ReLU values only when the returned ``relu`` is left
    unused (presumably NetSpec then omits that ReLU layer). Confirm before
    relying on it.
    NOTE(review): the conv bias is redundant ahead of BatchNorm plus
    Scale(bias_term=True). It is kept so the generated model is unchanged.
    """
    conv = L.Convolution(bottom,
                         kernel_size=ks,
                         stride=stride,
                         num_output=nout,
                         pad=pad,
                         bias_term=True,
                         weight_filler=dict(type='xavier'),
                         bias_filler=dict(type='constant'))

    # Mini-batch statistics while training, stored statistics otherwise.
    global_stats = split != "train"
    # Three frozen param entries. BatchNorm's blobs are accumulated statistics,
    # not solver-trained weights; the learnable affine part is the Scale below.
    frozen = [dict(lr_mult=0, decay_mult=0) for _ in range(3)]
    bn = L.BatchNorm(conv,
                     batch_norm_param=dict(use_global_stats=global_stats),
                     in_place=True,
                     param=frozen)

    scale = L.Scale(bn, scale_param=dict(bias_term=True), in_place=True)
    return scale, L.ReLU(scale, in_place=True)

def ResNet_block(split, bottom, nout, ks, stride, projection_stride, pad):
    """Append one two-convolution residual block and return its output top.

    Main path:
        conv(ks, stride=projection_stride) -> BN -> Scale -> ReLU
        -> conv(ks, stride=stride) -> BN -> Scale

    Shortcut: ``bottom`` itself when projection_stride == 1; otherwise a
    1x1 conv/BN/Scale with stride ``projection_stride`` and no padding.

    The two paths are combined with an Eltwise SUM, followed by an in-place
    ReLU, which is the returned top.

    NOTE(review): the identity shortcut assumes ``bottom`` already has
    ``nout`` channels and the same spatial size as the main path. Otherwise
    the Eltwise sum cannot be built. Confirm for new callers.
    """
    if projection_stride == 1:
        shortcut = bottom
    else:
        shortcut, _ = conv_BN_scale_relu(split, bottom, nout, 1, projection_stride, 0)

    _, hidden = conv_BN_scale_relu(split, bottom, nout, ks, projection_stride, pad)
    # Use the Scale (pre-activation) top of the second conv: the nonlinearity
    # is applied after the residual sum, not before it.
    residual, _ = conv_BN_scale_relu(split, hidden, nout, ks, stride, pad)

    summed = L.Eltwise(residual, shortcut, operation=P.Eltwise.SUM)
    return L.ReLU(summed, in_place=True)

def make_net():
    """Write train.prototxt and test.prototxt for ResNet.

    Both files are written to ``<this_dir>/res_net_model/``. That directory
    is created first if it is missing, so a fresh checkout does not fail
    inside ``open()``.
    """
    out_dir = os.path.join(this_dir, 'res_net_model')
    # os.makedirs has no exist_ok argument on Python 2; this check-then-create
    # form works on both Python 2 and 3.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for split in ('train', 'test'):
        with open(os.path.join(out_dir, split + '.prototxt'), 'w') as f:
            f.write(str(ResNet(split)))


if __name__ == '__main__':
    make_net()
内容来自用户分享和网络整理，不保证内容的准确性，如有侵权内容，可联系管理员处理。点击这里给我发消息
标签: 
相关文章推荐