您的位置:首页 > 理论基础 > 计算机网络

利用 caffe 接口构建 CNN 网络

2016-08-29 14:53 375 查看
我的研究重点原本是在 Torch 上,我也很喜欢用 Torch 去实现网络。但最近不得不转到 caffe 上。

在实现上篇博文:论文阅读:Reading Text in the Wild with Convolutional Neural Networks 的代码时,在 Bounding Box Regression 部分,需要用 caffe 来实现这个网络。

而一开始我构建论文中提到的这个 CNN 网络时,并没有用 caffe 提供的接口,而是直接手写
train.prototxt
test.prototxt
文件。结果除了很多错,caffe 的 layers 名称也在变化,所以极其不推荐这种方式。我这个网络较浅时还好,一旦网络很深,写起来累死人。

我将这个网络用 caffe 中的
draw_net.py
脚本生成图像,网络结构如下:



调用 caffe 的 python 接口,网络构建方式如下:

import os, sys, glob

CAFFE_PATH = '/home/chenxp/caffe/python'
sys.path.append(CAFFE_PATH)
import caffe
import caffe.io

from caffe import layers as L
from caffe import params as P

import h5py

def Bounding_Box_Reg(hdf5, batch_size):
n = caffe.NetSpec()
n.data, n.label = L.HDF5Data(batch_size=batch_size, source=hdf5, ntop=2)
n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=64, pad=2, weight_filler=dict(type='xavier'))
n.relu1 = L.ReLU(n.conv1, in_place=True)
n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=128, pad=2, weight_filler=dict(type='xavier'))
n.relu2 = L.ReLU(n.conv2, in_place=True)
n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv3 = L.Convolution(n.pool2, kernel_size=3, num_output=256, pad=1, weight_filler=dict(type='xavier'))
n.relu3 = L.ReLU(n.conv3, in_place=True)
n.pool3 = L.Pooling(n.conv3, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.conv4 = L.Convolution(n.pool3, kernel_size=3, num_output=512, pad=1, weight_filler=dict(type='xavier'))
n.relu4 = L.ReLU(n.conv4, in_place=True)
n.pool4 = L.Pooling(n.conv4, kernel_size=2, stride=2, pool=P.Pooling.MAX)
n.ip1   = L.InnerProduct(n.pool4, num_output=4000, weight_filler=dict(type='xavier'))
n.dp1   = L.Dropout(n.ip1, dropout_ratio=0.5)
n.ip2   = L.InnerProduct(n.ip1, num_output=4, weight_filler=dict(type='xavier'))
n.loss  = L.EuclideanLoss(n.ip2, n.label)

return n.to_proto()

with open('BBR_train.prototxt', 'w') as f:
f.write(str(Bounding_Box_Reg('train.h5', 16)))

with open('BBR_test.prototxt', 'w') as f:
f.write(str(Bounding_Box_Reg('test.h5', 16)))


最后生成两个
prototxt
文件,一个是
BBR_train.prototxt
,另一个是
BBR_test.prototxt
文件。

下图是自动生成的
BBR_train.prototxt
文件:



这样的方式更快速,还不容易出错~^_^
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: