您的位置:首页 > Web前端

Caffe小玩意(3)-利用py-faster-rcnn自定义输入数据

2016-07-01 12:20 459 查看

Caffe小玩意(3)-利用py-faster-rcnn自定义输入数据

众所周知,caffe是现有deep learning framework中最为自动化的,我们甚至可以只定义prototxt文件而不需要写代码,就完成整个网络的训练。正是由于它的高度自动化,当我们想要修改其中的模块,就不是一件容易的事了。

caffe本身自带了一些标准通用的dataset,我们可以比较简单地使用它们。此外,对于一些其他的输入形式,caffe也给出了一些指示:

http://caffe.berkeleyvision.org/tutorial/data.html

http://caffe.berkeleyvision.org/tutorial/layers.html#data-layers

但是,对于那种label不是简单变量的输入,我们应该怎么输入到caffe里呢?(例如:显著性检测问题,我们的label应该是一幅灰度图像;人体关节检测问题,我们的label应该是一个tensor)。那么今天,我们就来看看如何利用rbg大神的py-faster-rcnn框架来自己定制输入数据:

https://github.com/rbgirshick/py-faster-rcnn

首先,pull这个repository到本地目录(之后的./就代表在本地下的这个目录),然后运行./data/目录下的script下载数据(这些数据本身不是必要的,只是因为我之前需要finetune模型将它们下载了下来,之后的路径、操作等等也基于这一事实)。

好了,现在会多了一些目录出来。我们需要将数据(假设图像输入就是.jpg文件,label是python的numpy array,即.npy文件)放到相应的地方,即./data/VOCdevkit2007/VOC2007/JPEGImages。但是呢,这个目录下是原来的VOC2007数据集的图像输入,所以我建议在这个目录下再新建一个目录(这里叫dlib)。因此实际存放路径是:

./data/VOCdevkit2007/VOC2007/JPEGImagesd/dlib


之后,我们需要为这些输入数据写xml文件。每一份输入图像都对应一个xml文件,内容如下(不需要注重格式):

<annotation><folder>VOC2007</folder><filename>image_0046.jpg</filename><source><database>dlib facial landmark</database><annotation>Yuliang Zou</annotation></source><size><width>400</width><height>300</height><depth>3</depth></size><segmented>0</segmented></annotation>


同样地,为了与原来数据的xml文件混淆,新建一个dlib文件夹,因此这些新xml文件的存放路径为:

./data/VOCdevkit2007/VOC2007/Annotations/dlib


以上的操作,都没有对dataset进行training set与test set的区分,下面我们就来完成这件事。打开目录:

./data/VOCdevkit2007/VOC2007/ImageSets/Main


我们可以看到很多的txt文件,先把原来的trainval.txt与test.txt备份好。然后,新建自己的trainval.txt与test.txt,每一行都是输入图像的名称,例:

dlib/100032540_1

dlib/1002681492_1

dlib/1004467229_1

...


(直接用这两个txt文件的名字,是因为改用新的会有点麻烦,详见附录)

之后,我们需要修改相应的python代码使得数据可以顺利导入。

(1)在
lib/roi_data_layer/layer.py
里的setup()函数,我们需要添加如下代码,为label分配空间:

top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 68, 38, 50)
self._name_to_top_map['heatmap'] = idx
idx += 1


我这里的label是facial landmark,一共有68个2-d array。然后把下面的一些不需要的部分删掉(不然之后可能会报错)。

(2)在
.lib/utils/blob.py
里新定义一个函数:

def heatmap_list_to_blob(hms):
""" Convert a list of heat maps into a network input."""
num_hms = len(hms)
blob = np.zeros((num_hms, 68, 38, 50), dtype=np.float32)
for i in xrange(num_hms):
hm = hms[i]
blob[i] = hm.transpose((2,0,1))
return blob


这个函数可以将包含若干label的python list转换为caffe的blob数据结构。

(3)在
lib/roi_data_layer/minibatch.py
里导入刚刚定义的heatmap_list_to_blob函数,然后新定义函数:

def _get_heatmap_blob(roidb):
"""Get a batch of heat maps"""
num_images = len(roidb)
hms = []
for i in xrange(num_images):
hm = np.load(roidb[i]['heatmap'])
hms.append(hm)

# Create a blob to hold the input heat maps
blob = heatmap_list_to_blob(hms)

return blob


然后,在get_minibatch()函数中加入如下几行代码:

# Get the imput heat map blob, formatted for caffe
hm_blob = _get_heatmap_blob(roidb)
blobs['heatmap'] = hm_blob


(4)在
./lib/roi_data_layer/roidb.py
的prepare_roidb()函数中这行代码之后:

roidb[i]['image'] = imdb.image_path_at(i)


加入这么一行:

roidb[i]['heatmap'] = roidb[i]['image'][0:len(roidb[i]['image'])-3] + 'npy'


相信看到这里大家也知道了,
imdb.image_path_at(i)
获取的是输入图像的完整路径,我们进行些许修改就可以得到label的完整路径。

最后,我们需要修改train.prototxt,这个按自己的需要定制就可以了,比较简单,就不详述了。

在最后之后,如果要测试性能,需要自己对
./lib/fast_rcnn/test.py
进行修改。这里不再详述,我相信当你成功地开始训练的时候,已经对这些内容比较了解了,可以比较容易地写出自己需要的版本。

当然,完成了以上的所有步骤之后,可能还是会出现某些问题。

1.毕竟我的xml文件比原来的简化了不少,可以按实际情况删掉相应的code(原来的代码可能会导入xml文件的一些参数,但是我省略了那些参数)

2.原来代码对于输入图像的scaling比较奇怪,那边有可能会出错。对于某些输入尺寸固定的dataset,或许你可以修改

lib/roi_data_layer/layer.py
里的setup()函数,其中会有一行

top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3, _, _)


最后的两个参数是height和width,按需要修改。

最近折腾这个东西也折腾了很久,甚是头疼,更是加深了我对rbg大神的仰慕之情。行文有些混乱,如果有不明白的欢迎留言,大家一起交流。

附录:

./lib/datasets/factory.py
这份代码负责构造dataset

line 15 - 20:

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year: pascal_voc(split, year))


./experiments/scripts/faster_rcnn_end2end.sh
这份bash文件负责指定训练与测试时所用的dataset

line 27 - 28:

TRAIN_IMDB="voc_2007_trainval"
TEST_IMDB="voc_2007_test"
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: