您的位置:首页 > Web前端

使用py-faster-rcnn训练自己的数据

2016-05-28 11:02 645 查看

使用py-faster-rcnn训练自己的数据

faster-rcnn原本使用voc2007数据集训练,可识别20类目标。但是我们的应用场景,可能只需要2、3类。假设我们只需要检测识别两类目标
person
car
,再加上
backgroud
可以认为是分3类。那么我们如何训练这个模型呢。

1. 准备数据集

可以把自己的数据改成voc2007的格式,主要有两个文件:
trianval.txt
存放训练图片的名字,
annotation.xml
存储标注信息。具体过程,这里有一个参考。

然后在补充一点,如果你没有足够训练数据,或者干脆就只用voc2007的数据,而且只训练
person
car
,该怎么办呢。

很简单,
voc2007/Imagesets/Main
文件夹下,有各个类别的txt文件,我们需要使用的是
person_trainval.txt
car_trainval.txt
。这两个文件内容如下:

002247  1
002248  1
002249 -1
002251 -1


前边的数字串使图片的名字,后边的 1 或 -1 代表正负该图片是否存在该目标(person或car)。我们只使用标注为1的图片,使用shell语句完成:

cat person_trainval.txt car_trainval.txt | awk '{if($2 == 1) {print $1}}' >> my_trainval.txt


2. 修改
pascal_voc.py
文件

修改前注意备份

- classes:

def __init__(self, image_set, year, devkit_path=None):
imdb.__init__(self, 'voc_' + year + '_' + image_set)
self._year = year
self._image_set = image_set
self._devkit_path = self._get_default_path() if devkit_path is None \
else devkit_path
self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
self._classes = ('__background__', 'person', 'car')

self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
self._image_ext = '.jpg'
self._image_index = self._load_image_set_index()
# Default to roidb handler
self._roidb_handler = self.selective_search_roidb
self._salt = str(uuid.uuid4())
self._comp_id = 'comp4'

# PASCAL specific config options
self.config = {'cleanup'     : True,
'use_salt'    : True,
'use_diff'    : False,
'matlab_eval' : False,
'rpn_file'    : None,
'min_size'    : 2}

assert os.path.exists(self._devkit_path), \
'VOCdevkit path does not exist: {}'.format(self._devkit_path)
assert os.path.exists(self._data_path), \
'Path does not exist: {}'.format(self._data_path)


annotation

def _load_pascal_annotation(self, index):
"""
Load image and bounding boxes info from XML file in the PASCAL VOC
format.
"""
filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
tree = ET.parse(filename)
objs = tree.findall('object')
if not self.config['use_diff']:
# Exclude the samples labeled as difficult
non_diff_objs = [
obj for obj in objs if (int(obj.find('difficult').text) == 0 and obj.find('name').text == 'person' or obj.find('name').text == 'car')]
# if len(non_diff_objs) != len(objs):
#     print 'Removed {} difficult objects'.format(
#         len(objs) - len(non_diff_objs))
objs = non_diff_objs
num_objs = len(objs)

boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
# "Seg" area for pascal is just the box area
seg_areas = np.zeros((num_objs), dtype=np.float32)


3. 修改
train.prototxt
test.prototxt

有两种训练方式
alt_opt
end2end
,请根据情况细心修改,注意目标类别个数是 3,cls的输出是 3,box的输出是 3*4 = 12。这里有一个参考。

做完以上,就可以启动训练了,good luck…

如果要使用cpu训练,参考http://blog.csdn.net/qq_14975217/article/details/51495844
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息