您的位置:首页 > 编程语言

1. /tools/train_net.py ( Faster-RCNN_TF代码解读)

2018-03-26 12:27 561 查看

1. /tools/train_net.py

我用的Faster-RCNN是tensorflow版本,github地址:Faster-RCNN_TF

代码运行是从train_net.py进入的。

调用函数链接:

imdb = get_imdb(args.imdb_name)中的get_imdb函数在/lib/datasets/factor.py中。

roidb = get_training_roidb(imdb)中的get_training_roidb函数在train.py中。

network = get_network(xxxxx)中的get_network函数在/lib/networks/factory.py中。

train_net(xxxxx)中的train_net在函数/lib/fast_rcnn/train.py中。

代码解读:

#!/usr/bin/env python

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg,cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
from networks.factory import get_network
import argparse
import pprint
import numpy as np
import sys
import pdb
#设置参数,dest为目标,可通过args.XXX来访问
#通过命令行调用/experiments/scripts/faster_rcnn_end2end.sh,就是在设置参数
def parse_args():
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--device', dest='device', help='device to use',
default='cpu', type=str)
parser.add_argument('--device_id', dest='device_id', help='device id to use',
default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',
default=70000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',
default='kitti_train', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--network', dest='network_name',
help='name of the network',
default=None, type=str)
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER)
#如果sys.argv长度为1,则说明没有参数传入,系统会退出
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)

args = parser.parse_args()
return args

if __name__ == '__main__':
args = parse_args()

print('Called with args:')
print(args)
#如果还有其他配置文件,就加载
if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)

print('Using config:')
#已知类型的前提下,可以使用pprint来标准打印
pprint.pprint(cfg)

if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
#imdb为存在一个字典(easydict)里的pascal_voc类的一个对象,e.g.{voc_2007_train:内容,voc_2007_val:内容,voc_2007_test:内容,voc_2007_test:内容,voc_2012_train:内容...}
#内容里有该类里的各种self名称与操作,包括roi信息等等
#get_imdb函数在/lib/datasets/factory.py中:
#[factor.py](https://blog.csdn.net/u014256231/article/details/79696391)
imdb = get_imdb(args.imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
#get_training_roidb函数其实就是将所有的bbox水平翻转一次,然后返回训练需要用的roidb
#这是一个列表,列表中存的是各个图片的字典,字典中存roi信息,字典引索为图片引索
#get_training_roidb函数在/lib/fast_rcnn/train.py中
#[train.py](https://blog.csdn.net/u014256231/article/details/79696680)
roidb = get_training_roidb(imdb)
#输出路径
output_dir = get_output_dir(imdb, None)
print 'Output will be saved to `{:s}`'.format(output_dir)
#/(args.device)(args.device_id)
device_name = '/{}:{:d}'.format(args.device,args.device_id)
print device_name
#得到网络结构,参数(包括rpn)
#get_network在函数/lib/networks/factory.py中,
#[factory.py](https://blog.csdn.net/u014256231/article/details/79696984)
network = get_network(args.network_name)
print 'Use network `{:s}` in training'.format(args.network_name)
#train_net在函数/lib/fast_rcnn/train.py中,
#[train.py](https://blog.csdn.net/u014256231/article/details/79696680)
train_net(network, imdb, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)


a1b9
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息