Faster-RCNN_TF代码解读14:roi_data_layer/layer.py
2017-09-19 09:30
344 查看
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""The data layer used during training to train a Fast R-CNN network.

RoIDataLayer was originally a Caffe Python layer; here it is a plain Python
object whose forward() returns the blobs for the next minibatch.
"""

from fast_rcnn.config import cfg
from roi_data_layer.minibatch import get_minibatch
import numpy as np


class RoIDataLayer(object):
    """Fast R-CNN data layer used for training.

    Walks a random permutation of the roidb and builds minibatches of
    ``cfg.TRAIN.IMS_PER_BATCH`` images via ``get_minibatch``.
    """

    def __init__(self, roidb, num_classes):
        """Set the roidb to be used by this layer during training.

        Args:
            roidb: sequence of per-image dicts. Each entry must have a
                ``'boxes'`` array whose first dimension is the number of
                boxes (objects) in that image.
            num_classes: number of classes, forwarded to ``get_minibatch``.
        """
        self._roidb = roidb
        self._num_classes = num_classes
        self._shuffle_roidb_inds()

    def _shuffle_roidb_inds(self):
        """Randomly permute the training roidb and reset the read cursor."""
        # Random ordering of 0..len(roidb)-1; minibatch indices are read from
        # it sequentially starting at self._cur.
        self._perm = np.random.permutation(np.arange(len(self._roidb)))
        self._cur = 0

    def _get_next_minibatch_inds(self):
        """Return the roidb indices for the next minibatch.

        Returns:
            Array of roidb indices, normally ``cfg.TRAIN.IMS_PER_BATCH`` long.

        Raises:
            ValueError: (non-RPN mode only) if the roidb is empty or no entry
                has any boxes, so no valid minibatch can ever be formed.
        """
        if cfg.TRAIN.HAS_RPN:
            # Take the next contiguous slice of the permutation. Reshuffle only
            # when fewer than IMS_PER_BATCH indices remain; using ``>`` (not
            # ``>=``) keeps a final batch that fits exactly.
            if self._cur + cfg.TRAIN.IMS_PER_BATCH > len(self._roidb):
                self._shuffle_roidb_inds()

            db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]
            self._cur += cfg.TRAIN.IMS_PER_BATCH
        else:
            # Sample images one at a time from the permutation, skipping any
            # image that has no boxes.
            num_images = len(self._roidb)
            if num_images == 0:
                raise ValueError('roidb is empty; cannot sample a minibatch')

            db_inds = np.zeros((cfg.TRAIN.IMS_PER_BATCH), dtype=np.int32)
            i = 0
            # Consecutive images skipped because they had no boxes.
            num_skipped = 0
            while i < cfg.TRAIN.IMS_PER_BATCH:
                ind = self._perm[self._cur]
                # Number of boxes (objects) recorded for this image.
                num_objs = self._roidb[ind]['boxes'].shape[0]
                if num_objs != 0:
                    db_inds[i] = ind
                    i += 1
                    num_skipped = 0
                else:
                    num_skipped += 1
                    # If every entry is empty this loop could never finish;
                    # verify once per full pass of skips and fail loudly.
                    if num_skipped >= num_images:
                        if not any(entry['boxes'].shape[0] != 0
                                   for entry in self._roidb):
                            raise ValueError(
                                'no roidb entry has any boxes; '
                                'cannot sample a minibatch')
                        num_skipped = 0

                self._cur += 1
                if self._cur >= num_images:
                    self._shuffle_roidb_inds()

        return db_inds

    def _get_next_minibatch(self):
        """Return the blobs to be used for the next minibatch.

        Blobs are built synchronously by ``get_minibatch``; this class does
        no prefetching.
        """
        # Indices of the images that make up the next minibatch.
        db_inds = self._get_next_minibatch_inds()
        # Gather the corresponding per-image roidb entries (dicts) into a list.
        minibatch_db = [self._roidb[i] for i in db_inds]
        return get_minibatch(minibatch_db, self._num_classes)

    def forward(self):
        """Return the blobs for the next minibatch.

        Returns:
            Whatever ``get_minibatch`` produces for the next batch of images.
        """
        blobs = self._get_next_minibatch()
        return blobs
相关文章推荐
- Faster-RCNN_TF代码解读16:roi_data_layer/roidb.py
- Faster-RCNN_TF代码解读15:roi_data_layer/minibatch.py
- Faster-RCNN_TF代码解读17:anchor_target_layer_tf.py
- 7. anchor_target_layer_tf.py ( Faster-RCNN_TF代码解读)
- 9. proposal_target_layer_tf.py ( Faster-RCNN_TF代码解读)
- Faster-RCNN_TF代码解读9:proposal_target_layer_tf.py
- Faster-RCNN_TF代码解读10:proposal_layer_tf.py
- Faster-RCNN_TF代码解读2:datasets/factory.py
- Faster-RCNN_TF代码解读1:train-net.py
- Faster-RCNN_TF代码解读3:train.py
- Faster-RCNN_TF代码解读4:config.py
- Faster-RCNN_TF代码解读11:imdb.py
- Faster-RCNN_TF代码解读18:generate_anchors.py
- 5. VGGnet_train.py ( Faster-RCNN_TF代码解读)
- Faster-RCNN_TF代码解读5:networks/factory.py
- Faster-RCNN_TF代码解读6:pascal_voc.py
- Faster-RCNN_TF代码解读12:bbox_transform.py
- Faster-RCNN_TF代码解读20:blob.py
- 6. network.py ( Faster-RCNN_TF代码解读)
- 1. /tools/train_net.py ( Faster-RCNN_TF代码解读)