您的位置:首页 > 其它

用自己的数据训练Faster-RCNN,tensorflow版本(二)

2017-09-26 12:45 603 查看
我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF

参考博客http://www.cnblogs.com/CarryPotMan/p/5390336.html

用自己的数据训练Faster-RCNN,tensorflow版本(一)中我们详细介绍了Faster-rcnn_TF中pascal_voc数据的读写接口,接下来介绍一下,如何编写自己的数据读写接口。

3、编写自己的数据读写接口

我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。根据用自己的数据训练Faster-RCNN,tensorflow版本(一)中对pascal_voc.py文件的分析,发现,pascal_voc.py用了非常多的路径拼接,很麻烦,我们不用这么麻烦,简单一点就可以。

3.1、介绍一下我自己的训练数据集格式

我主要是从自然图片中检测出文本,因此我只有background 和text两类物体,我并没有像pascal_voc数据集里面一样每个图像用一个xml来标注,先说一下我的数据格式:

所有需要用到的数据我都放在了目录Data/ID_card/下面。

目录Data/ID_card/下面包含2个文件夹,分别是train,test。

先介绍train,目录Data/ID_card/train/里面包含:

1、所有的训练图片

2、gt_ID_card.txt

3、train.txt

我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面,gt_ID_card.txt格式如下:



以第一行为例:

ID_card/back_1.jpg: 是图片的名字;

数字1:代表该张图片上只有一个文本(text);

后面的四个数值:分别是文本框左上角和右下角的坐标。我的图片里面只有一行文本,所以只有一组文本框的坐标。

train.txt文件存放的是所有图片的名字,没有后缀,如下图:



3.2、编写自己的数据读写接口ID_card.py

主要修改的关键函数就是:def _load_annotation(self)——读取图片gt。

编写自己的数据读写接口ID_card.py,内容如下:

#coding:utf-8
# --------------------------------------------------------
#
# Written by lisiqi
# --------------------------------------------------------

import datasets
import os
import datasets.imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess

class ID_card(datasets.imdb):
def __init__(self, image_set, data_path=None):
datasets.imdb.__init__(self, 'ID_card_' + image_set) #image_set 为train或者val或者trainval或者test。
self._image_set = image_set # image_set以train为例
self._data_path = data_path # 数据所在的路径,根据传进来的参数data_path而定。传进来的参数data_path在我这里就是Data/ID_card/
self._classes = ('__background__','text') #object的类别,只有两类:背景和文本
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #构成字典{'__background__':'0','text':'1'}
self._image_ext = '.jpg' #图片后缀
self._image_index = self._load_image_set_index() #读取train.txt,获取图片名称(该图片名称没有后缀.jpg)
# Default to roidb handler
self._roidb_handler = self.gt_roidb #获取图片的gt
# PASCAL specific config options
self.config = {'cleanup'  : True,
'use_salt' : True,
'top_k'    : 2000}

assert os.path.exists(self._data_path), \ #如果路径Data/ID_card不存在,退出
'Image Path does not exist: {}'.format(self._data_path)

def image_path_at(self, i):#获得_image_index 下标为i的图像的路径
"""
Return the absolute path to image i in the image sequence.
"""
return self.image_path_from_index(self._image_index[i])

def image_path_from_index(self, index):#根据_image_index获取图像路径
"""
Construct an image path from the image's "index" identifier.
"""
image_path = os.path.join(self._data_path, index, self._image_ext)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path

def _load_image_set_index(self):#已做修改
"""
Load the indexes listed in this dataset's image set file.
得到图片名称的list。这个list里面是集合self._image_set=train中所有图片的名字(注意,图片名字没有后缀.jpg)
"""
image_set_file = os.path.join(self._data_path, self._image_set, self._image_set + '.txt')
#image_set_file是Data/ID_card/train/train.txt
#之所以要读这个train.txt文件,是因为train.txt文件里面写的是集合train中所有图片的名字(没有后缀.jpg)
assert os.path.exists(image_set_file), \
'Path does not exist: {}'.format(image_set_file)
with open(image_set_file) as f: #读取train.txt,获取图片名称(没有后缀.jpg)
image_index = [x.strip() for x in f.readlines()]
return image_index

def
4000
gt_roidb(self):
"""
Return the database of ground-truth regions of interest.
读取并返回图片gt的db。这个函数就是将图片的gt加载进来。
其中,图片的gt信息在gt_ID_card.txt文件中
并且,图片的gt被提前放在了一个.pkl文件里面。(这个.pkl文件需要我们自己生成,代码就在该函数中)

This function loads/saves from/to a cache file to speed up future calls.
之所以会将图片的gt提前放在一个.pkl文件里面,是为了不用每次都再重新读图片的gt,直接加载这个文件就可以了,可以提升速度。
"""
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):#若存在cache file则直接从cache file中读取
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
return roidb

gt_roidb = self._load_annotation()  #读入整个gt文件的具体实现
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)

return gt_roidb

#def selective_search_roidb(self):#在没有使用RPN的时候,是这样提取候选框,fast-rcnn会用到。我直接删除了这个函数,faster-rcnn用不到
#def _load_selective_search_roidb(self, gt_roidb):#用不到,删除
#def selective_search_IJCV_roidb(self):  #用不到,删除
#def _load_selective_search_IJCV_roidb(self, gt_roidb): #用不到,删除

def _load_annotation(self):
"""
Load image and bounding boxes info from txt format.
读取图片的gt的具体实现。
我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面
gt_ID_card.txt中每行的格式如下:ID_card/train/back_1.jpg 1 147 65 443 361
后面的四个数值分别是文本框左上角和右下角的坐标。我的图片里面只有一个文本,所以只有一组文本框的坐标
"""
gt_roidb = []
txtfile = os.path.join(self._data_path, 'gt_ID_card.txt')
f = open(txtfile)
split_line = f.readline().strip().split()
num = 1
while(split_line):
num_objs = int(split_line[1])
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
for i in range(num_objs):
x1 = float( split_line[2 + i * 4])
y1 = float (split_line[3 + i * 4])
x2 = float (split_line[4 + i * 4])
y2 = float (split_line[5 + i * 4])
cls = self._class_to_ind['text']
boxes[i,:] = [x1, y1, x2, y2]
gt_classes[i] = cls
overlaps[i,cls] = 1.0

overlaps = scipy.sparse.csr_matrix(overlaps)
gt_roidb.append({'boxes' : boxes, 'gt_classes': gt_classes, 'gt_overlaps' : overlaps, 'flipped' : False})
split_line = f.readline().strip().split()

f.close()
return gt_roidb

#def _write_voc_results_file(self, all_boxes):#没用,删掉
#def _do_matlab_eval(self, comp_id, output_dir='output'): #没用,删掉
#def evaluate_detections(self, all_boxes, output_dir):# 没用,删掉

def competition_mode(self, on):
if on:
self.config['use_salt'] = False
self.config['cleanup'] = False
else:
self.config['use_salt'] = True
self.config['cleanup'] = True

if __name__ == '__main__':
import datasets.ID_card #作了修改
d = datasets.ID_card('train', 'Data/ID_card/')#datasets.ID_card()在factory.py中用到了,
res = d.roidb
from IPython import embed; embed()


到这里,就完成了整个的读取接口的改写,主要在gt的读取。

除了要修改数据读写接口,还有一些文件需要修改。

3.3、修改factory.py

建议先将原来的factory.py复制成factory_bak.py作为备份,然后再在factory.py上进行修改。

修改后的factory.py如下:

"""Factory method for easily getting imdbs by name."""

import datasets.ID_card as ID_card #首先在文件头import把pascal_voc改成ID_card

__sets = {}
image_set = 'train'
data_path = '/data/home/lisiqi/Data/ID_card' #自己数据的路径

def get_imdb(name): # 当网络训练时会调用factory里面的get_imdb方法获得相应的imdb
"""Get an imdb (image database) by name."""
__sets[name] = (lambda image_set=image_set, data_path=data_path: ID_card.ID_card(image_set,data_path)) #ID_card.ID_card()的意思是调用文件ID_card.py中的类ID_card
if not __sets.has_key(name):
raise KeyError('Unknown dataset: {}'.format(name))
return __sets[name]()

def list_imdbs():
"""List all registered imdbs."""
return __sets.keys()


3.4、修改模型文件配置

3.4.1、修改config.py

工程Faster-RCNN_TF中模型的参数都在文件Faster-RCNN_TF/lib/fast-rcnn/config.py中被定义。

将config.py中有如下参数的地方,按照下面的进行修改:

# Images to use per minibatch
__C.TRAIN.IMS_PER_BATCH = 1 #每次输入到faster-rcnn网络中的图片数量是1张

# Iterations between snapshots
__C.TRAIN.SNAPSHOT_ITERS = 1000  # 训练的时候,每1000步保存一次模型。这个可以自己随意设置

__C.TRAIN.SNAPSHOT_PREFIX = 'VGGnet_faster_rcnn' #模型在保存时的名字
# Use RPN to detect objects
__C.TRAIN.HAS_RPN = True #是否使用RPN。True代表使用RPN


3.4.2、修改VGG_train.py和VGG_test.py

要想启动Faster RCNN网络训练,需要用到文件Faster-RCNN_TF/lib/networks/VGGnet_train.py。

因为我的任务是检测自然图像中的文本,所以我的检测目标物是text,那么我的类别就有两个类别即 background 和 text。

VGGnet_train.py需要修改的地方如下:

把n_classes 从原来的21类(20类+背景) ,改成 2类(人+背景),其它不用变。



3.5、启动Faster RCNN网络训练

网络的训练文件是Faster_RCNN-TF/tools/train_net.py,内容如下:

"""Train a Fast R-CNN network on a region of interest database."""

import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg,cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
from networks.factory import get_network
import argparse
import pprint
import numpy as np
import sys
import os  #新增加的
import pdb #打断点时,会用到

def parse_args():
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--device', dest='device', help='device to use',
default='cpu', type=str)
parser.add_argument('--device_id', dest='device_id', help='device id to use',
default=0, type=int)
parser.add_argument('--solver', dest='solver',
help='solver prototxt',
default=None, type=str)
parser.add_argument('--iters', dest='max_iters',
help='number of iterations to train',
default=70000, type=int)
parser.add_argument('--weights', dest='pretrained_model',
help='initialize with pretrained model weights',
default=None, type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default=None, type=str)
parser.add_argument('--imdb', dest='imdb_name',
help='dataset to train on',
default='kitti_train', type=str)
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--network', dest='network_name',
help='name of the network',
default=None, type=str)
parser.add_argument('--set', dest='set_cfgs',
help='set config keys', default=None,
nargs=argparse.REMAINDER)

if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)

args = parser.parse_args()
return args

if __name__ == '__main__':
args = parse_args()
print('Called with args:')
print(args)

if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)

print(
ed35
'Using config:')
pprint.pprint(cfg)
if not args.randomize:
# fix the random seeds (numpy and caffe) for reproducibility
np.random.seed(cfg.RNG_SEED)
imdb = get_imdb(args.imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb.name)
roidb = get_training_roidb(imdb)

output_dir = get_output_dir(imdb, None)
print 'Output will be saved to `{:s}`'.format(output_dir)

# 设置cpu或者gpu的id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id)
device_name = '/{}:{:d}'.format(args.device,args.device_id)
print device_name

network = get_network(args.network_name)
print 'Use network `{:s}` in training'.format(args.network_name)
#pdb.set_trace() #在此处设置一个断点
train_net(network, imdb, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)


在终端,启动网络训练。在路径Faster-RCNN_TF下,输入:

python ./tools/train_net.py --device gpu --device_id 3 --solver VGG_CNN_M_1024 --weight ./data/pretrain_model/VGG_imagenet.npy --imdb ID_card_train --network IDcard_train


参数解释:

train_net.py: 是网络的训练文件

—device :代表选用cpu还是gpu

—device_id: 代表机器上的cpu或者gpu的编号

—solver: 模型的配置文件,这个参数就不要进行修改了,固定就是VGG_CNN_M_1024

—weight: 初始化的权重文件,这里用的是Imagenet上预训练好的模型VGG_imagenet.npy,存放在目录./data/pretrain_model下

—imdb: 训练的数据库名字,这个名字可以自己随便起

—network: 代表选择训练网络还是测试网络,这个参数的形式是固定的,必须是IDcard_train的形式,前半部分IDcard可以随便起(但是不能有下划线),后半部分必须是_train

训练完成之后的模型默认保存在了目录./output/default/ID_card_train下。我们会发现,该目录下会出现以下文件:



TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。模型会保存在后缀为.ckpt的文件中。保存后,在目录./output/default/ID_card_train下会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。参考自TensorFlow学习笔记(8)–网络模型的保存和读取

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。


3.6、测试Faster RCNN网络训练的模型

参考./tools/demo.py,写自己的demo.py。

由于我所使用的服务器中无法使用plot,所以我将检测的坐标结果直接画在了测试图片上,并且将图片保存在了目录./results下。

修改后的demo.py内容如下:

import _init_paths
import tensorflow as tf
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import os, sys, cv2
import argparse
from networks.factory import get_network
import glob

import os
#os.environ['CUDA_VISIBLE_DEVICES']='3'

import pdb #设断点时,使用的

CLASSES = ('__background__', 'text')物体类别

def vis_detections(im, class_name, dets, image_name, thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
return
im = im.copy()

for i in inds:
bbox = dets[i, :4] #检测图片的坐标
score = dets[i, -1]

cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 3) #将坐标直接画在了检测图片上

cv2.imwrite('./results/' + image_name, im) #将带有框的检测图片保存在目录./results下
print class_name

def demo(sess, net, image_name_path):
"""Detect object classes in an image using pre-computed object proposals."""

# load images
im = cv2.imread(image_name_path)
im_name = os.path.basename(image_name_path)

# Detect all object classes and regress object bounds
timer = Timer()
timer.tic()
scores, boxes = im_detect(sess, net, im)
print boxes
timer.toc()
print ('Detection took {:.3f}s for '
'{:d} object proposals').format(timer.total_time, boxes.shape[0])

# Visualize detections for each class
#im = im[:, :, (2, 1, 0)]
#fig, ax = plt.subplots(figsize=(12, 12))
# ax.imshow(im, aspect='equal')

CONF_THRESH = 0.5
NMS_THRESH = 0.3
for cls_ind, cls in enumerate(CLASSES[1:]):
cls_ind += 1 # because we skipped background
cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
cls_scores = scores[:, cls_ind]
dets = np.hstack((cls_boxes,
cls_scores[:, np.newaxis])).astype(np.float32)
keep = nms(dets, NMS_THRESH)
dets = dets[keep, :]
#vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
vis_detections(im, cls, dets, im_name, thresh=CONF_THRESH)
def parse_args():
"""Parse input arguments."""
parser = argparse.ArgumentParser(description='Faster R-CNN demo')
parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
default=0, type=int)
parser.add_argument('--cpu', dest='cpu_mode',
help='Use CPU mode (overrides --gpu)',
action='store_true')
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
default='VGGnet_test')
parser.add_argument('--model', dest='model', help='Model path',
default='/data/home/lisiqi/Faster-RCNN_TF_original/weight/VGGnet_fast_rcnn_iter_70000.ckpt')
parser.add_argument('--results_dir', dest='results_dir', help='Results director',
default=' ')

args = parser.parse_args()

return args
if __name__ == '__main__':
cfg.TEST.HAS_RPN = True  # Use RPN for proposals

args = parse_args()

# GPU id(设置GPU的编号)
os.environ['CUDA_VISIBLE_DEVICES']=str(args.gpu_id)

if args.model == ' ':
raise IOError(('Error: Model not found.\n'))

if args.results_dir == ' ':
raise IOError(('Error: Result director not found.\n'))

# init session
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# load network
net = get_network(args.demo_net)
# load model
#saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
#pdb.set_trace()
saver = tf.train.Saver()
saver.restore(sess, args.model) #加载训练好的模型,名称就写到.ckpt就行,例如VGGnet_faster_rcnn_iter_1000.ckpt

#sess.run(tf.initialize_all_variables())

print '\n\nLoaded network {:s}'.format(args.model)

# Warmup on a dummy image
im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
for i in xrange(2):
_, _= im_detect(sess, net, im)

# load images
im_file_dir = cfg.DATA_DIR + '/demo/'
im_names_path = glob.glob(im_file_dir + '*.jpg')
#pdb.set_trace()

for im_name_path in im_names_path:
#im_name = os.path.basename(im_name_path)
print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
print 'Demo for data/demo/{}'.format(im_name_path)
#pdb.set_trace()
demo(sess, net, im_name_path)

print 'results_dir:{}'.format(args.results_dir)
#plt.show()


在终端,路径Faster-RCNN_TF下,输入:

python ./tools/demo.py --gpu 3 --model ./output/default/ID_card_train/VGGnet_faster_rcnn_iter_1000.ckpt --results ./results/


参数解释:

demo.py: 测试图片的文件

—gpu : 代表机器上gpu的编号(直接就默认使用gpu,没有cpu选项)

—model: 网络训练好的模型

—results: 保存结果的路径
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: