您的位置:首页 > 编程语言

Faster-RCNN_TF代码解读14:roi_data_layer/layer.py

2017-09-19 09:30 344 查看
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""The data layer used during training to train a Fast R-CNN network.

RoIDataLayer implements a Caffe Python layer.
"""

from fast_rcnn.config import cfg
from roi_data_layer.minibatch import get_minibatch
import numpy as np

class RoIDataLayer(object):
    """Fast R-CNN data layer used for training.

    Holds a roidb (a sequence of per-image entries) and hands out
    minibatches by walking a random permutation of the roidb indices,
    reshuffling whenever the permutation runs out.
    """

    def __init__(self, roidb, num_classes):
        """Set the roidb to be used by this layer during training.

        Args:
            roidb: Sequence of per-image entries. When
                ``cfg.TRAIN.HAS_RPN`` is false each entry is indexed with
                ``['boxes']`` to count its objects.
            num_classes: Passed through unchanged to ``get_minibatch``.
        """
        self._roidb = roidb
        self._num_classes = num_classes
        self._shuffle_roidb_inds()

    def _shuffle_roidb_inds(self):
        """Randomly permute the training roidb and rewind the cursor."""
        # A shuffled copy of [0, len(roidb)); self._cur indexes into it.
        self._perm = np.random.permutation(np.arange(len(self._roidb)))
        self._cur = 0

    def _get_next_minibatch_inds(self):
        """Return the roidb indices for the next minibatch.

        ``cfg.TRAIN.IMS_PER_BATCH`` indices are taken in permutation order.

        Raises:
            ValueError: In the non-RPN branch, if a full pass over the roidb
                finds no entry with at least one box (the original code
                looped forever in that situation).
        """
        batch_size = cfg.TRAIN.IMS_PER_BATCH

        # NOTE(review): the original annotation marks HAS_RPN as False and
        # IMS_PER_BATCH as 2 for the annotator's setup -- confirm against
        # the configuration actually in use.
        if cfg.TRAIN.HAS_RPN:
            # Reshuffle when fewer than batch_size + 1 entries remain. Because
            # the test is ">=", a batch that would end exactly on the last
            # entry of the permutation is skipped rather than served.
            if self._cur + batch_size >= len(self._roidb):
                self._shuffle_roidb_inds()

            db_inds = self._perm[self._cur:self._cur + batch_size]
            self._cur += batch_size
        else:
            # Sample images one at a time, skipping those without objects.
            db_inds = np.zeros(batch_size, dtype=np.int32)
            filled = 0
            # Any run of 2 * len(roidb) consecutive cursor positions spans
            # at least one complete permutation, even across a reshuffle,
            # so that many empties in a row means no entry has boxes.
            empty_streak = 0
            max_empty_streak = 2 * len(self._roidb)
            while filled < batch_size:
                ind = self._perm[self._cur]
                # Number of objects annotated in the current image.
                num_objs = self._roidb[ind]['boxes'].shape[0]
                if num_objs != 0:
                    db_inds[filled] = ind
                    filled += 1
                    empty_streak = 0
                else:
                    empty_streak += 1
                    if empty_streak >= max_empty_streak:
                        raise ValueError(
                            'roidb contains no entries with boxes; '
                            'cannot build a minibatch')

                self._cur += 1
                if self._cur >= len(self._roidb):
                    self._shuffle_roidb_inds()

        return db_inds

    def _get_next_minibatch(self):
        """Return the blobs to be used for the next minibatch.

        Picks the next batch of roidb indices, gathers the corresponding
        roidb entries into a list, and delegates to ``get_minibatch``.
        """
        # Indices of the next batch.
        db_inds = self._get_next_minibatch_inds()
        # Collect the per-image entries for those indices into a list.
        minibatch_db = [self._roidb[i] for i in db_inds]
        return get_minibatch(minibatch_db, self._num_classes)

    def forward(self):
        """Fetch the next minibatch and return its blobs."""
        blobs = self._get_next_minibatch()
        return blobs
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息