用自己的数据训练Faster-RCNN,tensorflow版本(二)
2017-09-26 12:45
603 查看
我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF
参考博客http://www.cnblogs.com/CarryPotMan/p/5390336.html
在用自己的数据训练Faster-RCNN,tensorflow版本(一)中我们详细介绍了Faster-rcnn_TF中pascal_voc数据的读写接口,接下来介绍一下,如何编写自己的数据读写接口。
3.1、介绍一下我自己的训练数据集格式
我主要是从自然图片中检测出文本,因此我只有background 和text两类物体,我并没有像pascal_voc数据集里面一样每个图像用一个xml来标注,先说一下我的数据格式:
所有需要用到的数据我都放在了目录Data/ID_card/下面。
目录Data/ID_card/下面包含2个文件夹,分别是train,test。
先介绍train,目录Data/ID_card/train/里面包含:
1、所有的训练图片
2、gt_ID_card.txt
3、train.txt
我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面,gt_ID_card.txt格式如下:
以第一行为例:
ID_card/back_1.jpg: 是图片的名字;
数字1:代表该张图片上只有一个文本(text);
后面的四个数值:分别是文本框左上角和右下角的坐标。我的图片里面只有一行文本,所以只有一组文本框的坐标。
train.txt文件存放的是所有图片的名字,没有后缀,如下图:
3.2、编写自己的数据读写接口ID_card.py
主要修改的关键函数就是:def _load_annotation(self)——读取图片gt。
编写自己的数据读写接口ID_card.py,内容如下:
到这里,就完成了整个的读取接口的改写,主要在gt的读取。
除了要修改数据读写接口,还有一些文件需要修改。
3.3、修改factory.py
建议先将原来的factory.py复制成factory_bak.py作为备份,然后再在factory.py上进行修改。
修改后的factory.py如下:
3.4、修改模型文件配置
3.4.1、修改config.py
工程Faster-RCNN_TF中模型的参数都在文件Faster-RCNN_TF/lib/fast-rcnn/config.py中被定义。
将config.py中有如下参数的地方,按照下面的进行修改:
3.4.2、修改VGG_train.py和VGG_test.py
要想启动Faster RCNN网络训练,需要用到文件Faster-RCNN_TF/lib/networks/VGGnet_train.py。
因为我的任务是检测自然图像中的文本,所以我的检测目标物是text,那么我的类别就有两个类别即 background 和 text。
VGGnet_train.py需要修改的地方如下:
把n_classes 从原来的21类(20类+背景) ,改成 2类(人+背景),其它不用变。
3.5、启动Faster RCNN网络训练
网络的训练文件是Faster_RCNN-TF/tools/train_net.py,内容如下:
在终端,启动网络训练。在路径Faster-RCNN_TF下,输入:
参数解释:
train_net.py: 是网络的训练文件
—device :代表选用cpu还是gpu
—device_id: 代表机器上的cpu或者gpu的编号
—solver: 模型的配置文件,这个参数就不要进行修改了,固定就是VGG_CNN_M_1024
—weight: 初始化的权重文件,这里用的是Imagenet上预训练好的模型VGG_imagenet.npy,存放在目录./data/pretrain_model下
—imdb: 训练的数据库名字,这个名字可以自己随便起
—network: 代表选择训练网络还是测试网络,这个参数的形式是固定的,必须是IDcard_train的形式,前半部分IDcard可以随便起(但是不能有下划线),后半部分必须是_train
训练完成之后的模型默认保存在了目录./output/default/ID_card_train下。我们会发现,该目录下会出现以下文件:
TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。模型会保存在后缀为.ckpt的文件中。保存后,在目录./output/default/ID_card_train下会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。参考自TensorFlow学习笔记(8)–网络模型的保存和读取
3.6、测试Faster RCNN网络训练的模型
参考./tools/demo.py,写自己的demo.py。
由于我所使用的服务器中无法使用plot,所以我将检测的坐标结果直接画在了测试图片上,并且将图片保存在了目录./results下。
修改后的demo.py内容如下:
在终端,路径Faster-RCNN_TF下,输入:
参数解释:
demo.py: 测试图片的文件
—gpu : 代表机器上gpu的编号(直接就默认使用gpu,没有cpu选项)
—model: 网络训练好的模型
—results: 保存结果的路径
参考博客http://www.cnblogs.com/CarryPotMan/p/5390336.html
在用自己的数据训练Faster-RCNN,tensorflow版本(一)中我们详细介绍了Faster-rcnn_TF中pascal_voc数据的读写接口,接下来介绍一下,如何编写自己的数据读写接口。
3、编写自己的数据读写接口
我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。根据用自己的数据训练Faster-RCNN,tensorflow版本(一)中对pascal_voc.py文件的分析,发现,pascal_voc.py用了非常多的路径拼接,很麻烦,我们不用这么麻烦,简单一点就可以。3.1、介绍一下我自己的训练数据集格式
我主要是从自然图片中检测出文本,因此我只有background 和text两类物体,我并没有像pascal_voc数据集里面一样每个图像用一个xml来标注,先说一下我的数据格式:
所有需要用到的数据我都放在了目录Data/ID_card/下面。
目录Data/ID_card/下面包含2个文件夹,分别是train,test。
先介绍train,目录Data/ID_card/train/里面包含:
1、所有的训练图片
2、gt_ID_card.txt
3、train.txt
我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面,gt_ID_card.txt格式如下:
以第一行为例:
ID_card/back_1.jpg: 是图片的名字;
数字1:代表该张图片上只有一个文本(text);
后面的四个数值:分别是文本框左上角和右下角的坐标。我的图片里面只有一行文本,所以只有一组文本框的坐标。
train.txt文件存放的是所有图片的名字,没有后缀,如下图:
3.2、编写自己的数据读写接口ID_card.py
主要修改的关键函数就是:def _load_annotation(self)——读取图片gt。
编写自己的数据读写接口ID_card.py,内容如下:
#coding:utf-8 # -------------------------------------------------------- # # Written by lisiqi # -------------------------------------------------------- import datasets import os import datasets.imdb import xml.dom.minidom as minidom import numpy as np import scipy.sparse import scipy.io as sio import utils.cython_bbox import cPickle import subprocess class ID_card(datasets.imdb): def __init__(self, image_set, data_path=None): datasets.imdb.__init__(self, 'ID_card_' + image_set) #image_set 为train或者val或者trainval或者test。 self._image_set = image_set # image_set以train为例 self._data_path = data_path # 数据所在的路径,根据传进来的参数data_path而定。传进来的参数data_path在我这里就是Data/ID_card/ self._classes = ('__background__','text') #object的类别,只有两类:背景和文本 self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #构成字典{'__background__':'0','text':'1'} self._image_ext = '.jpg' #图片后缀 self._image_index = self._load_image_set_index() #读取train.txt,获取图片名称(该图片名称没有后缀.jpg) # Default to roidb handler self._roidb_handler = self.gt_roidb #获取图片的gt # PASCAL specific config options self.config = {'cleanup' : True, 'use_salt' : True, 'top_k' : 2000} assert os.path.exists(self._data_path), \ #如果路径Data/ID_card不存在,退出 'Image Path does not exist: {}'.format(self._data_path) def image_path_at(self, i):#获得_image_index 下标为i的图像的路径 """ Return the absolute path to image i in the image sequence. """ return self.image_path_from_index(self._image_index[i]) def image_path_from_index(self, index):#根据_image_index获取图像路径 """ Construct an image path from the image's "index" identifier. """ image_path = os.path.join(self._data_path, index, self._image_ext) assert os.path.exists(image_path), \ 'Path does not exist: {}'.format(image_path) return image_path def _load_image_set_index(self):#已做修改 """ Load the indexes listed in this dataset's image set file. 得到图片名称的list。这个list里面是集合self._image_set=train中所有图片的名字(注意,图片名字没有后缀.jpg) """ image_set_file = os.path.join(self._data_path, self._image_set, self._image_set + '.txt') #image_set_file是Data/ID_card/train/train.txt #之所以要读这个train.txt文件,是因为train.txt文件里面写的是集合train中所有图片的名字(没有后缀.jpg) assert os.path.exists(image_set_file), \ 'Path does not exist: {}'.format(image_set_file) with open(image_set_file) as f: #读取train.txt,获取图片名称(没有后缀.jpg) image_index = [x.strip() for x in f.readlines()] return image_index def 4000 gt_roidb(self): """ Return the database of ground-truth regions of interest. 读取并返回图片gt的db。这个函数就是将图片的gt加载进来。 其中,图片的gt信息在gt_ID_card.txt文件中 并且,图片的gt被提前放在了一个.pkl文件里面。(这个.pkl文件需要我们自己生成,代码就在该函数中) This function loads/saves from/to a cache file to speed up future calls. 之所以会将图片的gt提前放在一个.pkl文件里面,是为了不用每次都再重新读图片的gt,直接加载这个文件就可以了,可以提升速度。 """ cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') if os.path.exists(cache_file):#若存在cache file则直接从cache file中读取 with open(cache_file, 'rb') as fid: roidb = cPickle.load(fid) print '{} gt roidb loaded from {}'.format(self.name, cache_file) return roidb gt_roidb = self._load_annotation() #读入整个gt文件的具体实现 with open(cache_file, 'wb') as fid: cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) print 'wrote gt roidb to {}'.format(cache_file) return gt_roidb #def selective_search_roidb(self):#在没有使用RPN的时候,是这样提取候选框,fast-rcnn会用到。我直接删除了这个函数,faster-rcnn用不到 #def _load_selective_search_roidb(self, gt_roidb):#用不到,删除 #def selective_search_IJCV_roidb(self): #用不到,删除 #def _load_selective_search_IJCV_roidb(self, gt_roidb): #用不到,删除 def _load_annotation(self): """ Load image and bounding boxes info from txt format. 读取图片的gt的具体实现。 我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面 gt_ID_card.txt中每行的格式如下:ID_card/train/back_1.jpg 1 147 65 443 361 后面的四个数值分别是文本框左上角和右下角的坐标。我的图片里面只有一个文本,所以只有一组文本框的坐标 """ gt_roidb = [] txtfile = os.path.join(self._data_path, 'gt_ID_card.txt') f = open(txtfile) split_line = f.readline().strip().split() num = 1 while(split_line): num_objs = int(split_line[1]) boxes = np.zeros((num_objs, 4), dtype=np.uint16) gt_classes = np.zeros((num_objs), dtype=np.int32) overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32) for i in range(num_objs): x1 = float( split_line[2 + i * 4]) y1 = float (split_line[3 + i * 4]) x2 = float (split_line[4 + i * 4]) y2 = float (split_line[5 + i * 4]) cls = self._class_to_ind['text'] boxes[i,:] = [x1, y1, x2, y2] gt_classes[i] = cls overlaps[i,cls] = 1.0 overlaps = scipy.sparse.csr_matrix(overlaps) gt_roidb.append({'boxes' : boxes, 'gt_classes': gt_classes, 'gt_overlaps' : overlaps, 'flipped' : False}) split_line = f.readline().strip().split() f.close() return gt_roidb #def _write_voc_results_file(self, all_boxes):#没用,删掉 #def _do_matlab_eval(self, comp_id, output_dir='output'): #没用,删掉 #def evaluate_detections(self, all_boxes, output_dir):# 没用,删掉 def competition_mode(self, on): if on: self.config['use_salt'] = False self.config['cleanup'] = False else: self.config['use_salt'] = True self.config['cleanup'] = True if __name__ == '__main__': import datasets.ID_card #作了修改 d = datasets.ID_card('train', 'Data/ID_card/')#datasets.ID_card()在factory.py中用到了, res = d.roidb from IPython import embed; embed()
到这里,就完成了整个的读取接口的改写,主要在gt的读取。
除了要修改数据读写接口,还有一些文件需要修改。
3.3、修改factory.py
建议先将原来的factory.py复制成factory_bak.py作为备份,然后再在factory.py上进行修改。
修改后的factory.py如下:
"""Factory method for easily getting imdbs by name.""" import datasets.ID_card as ID_card #首先在文件头import把pascal_voc改成ID_card __sets = {} image_set = 'train' data_path = '/data/home/lisiqi/Data/ID_card' #自己数据的路径 def get_imdb(name): # 当网络训练时会调用factory里面的get_imdb方法获得相应的imdb """Get an imdb (image database) by name.""" __sets[name] = (lambda image_set=image_set, data_path=data_path: ID_card.ID_card(image_set,data_path)) #ID_card.ID_card()的意思是调用文件ID_card.py中的类ID_card if not __sets.has_key(name): raise KeyError('Unknown dataset: {}'.format(name)) return __sets[name]() def list_imdbs(): """List all registered imdbs.""" return __sets.keys()
3.4、修改模型文件配置
3.4.1、修改config.py
工程Faster-RCNN_TF中模型的参数都在文件Faster-RCNN_TF/lib/fast-rcnn/config.py中被定义。
将config.py中有如下参数的地方,按照下面的进行修改:
# Images to use per minibatch __C.TRAIN.IMS_PER_BATCH = 1 #每次输入到faster-rcnn网络中的图片数量是1张 # Iterations between snapshots __C.TRAIN.SNAPSHOT_ITERS = 1000 # 训练的时候,每1000步保存一次模型。这个可以自己随意设置 __C.TRAIN.SNAPSHOT_PREFIX = 'VGGnet_faster_rcnn' #模型在保存时的名字 # Use RPN to detect objects __C.TRAIN.HAS_RPN = True #是否使用RPN。True代表使用RPN
3.4.2、修改VGG_train.py和VGG_test.py
要想启动Faster RCNN网络训练,需要用到文件Faster-RCNN_TF/lib/networks/VGGnet_train.py。
因为我的任务是检测自然图像中的文本,所以我的检测目标物是text,那么我的类别就有两个类别即 background 和 text。
VGGnet_train.py需要修改的地方如下:
把n_classes 从原来的21类(20类+背景) ,改成 2类(人+背景),其它不用变。
3.5、启动Faster RCNN网络训练
网络的训练文件是Faster_RCNN-TF/tools/train_net.py,内容如下:
"""Train a Fast R-CNN network on a region of interest database.""" import _init_paths from fast_rcnn.train import get_training_roidb, train_net from fast_rcnn.config import cfg,cfg_from_file, cfg_from_list, get_output_dir from datasets.factory import get_imdb from networks.factory import get_network import argparse import pprint import numpy as np import sys import os #新增加的 import pdb #打断点时,会用到 def parse_args(): """ Parse input arguments """ parser = argparse.ArgumentParser(description='Train a Fast R-CNN network') parser.add_argument('--device', dest='device', help='device to use', default='cpu', type=str) parser.add_argument('--device_id', dest='device_id', help='device id to use', default=0, type=int) parser.add_argument('--solver', dest='solver', help='solver prototxt', default=None, type=str) parser.add_argument('--iters', dest='max_iters', help='number of iterations to train', default=70000, type=int) parser.add_argument('--weights', dest='pretrained_model', help='initialize with pretrained model weights', default=None, type=str) parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str) parser.add_argument('--imdb', dest='imdb_name', help='dataset to train on', default='kitti_train', type=str) parser.add_argument('--rand', dest='randomize', help='randomize (do not use a fixed seed)', action='store_true') parser.add_argument('--network', dest='network_name', help='name of the network', default=None, type=str) parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER) if len(sys.argv) == 1: parser.print_help() sys.exit(1) args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() print('Called with args:') print(args) if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.set_cfgs is not None: cfg_from_list(args.set_cfgs) print( ed35 'Using config:') pprint.pprint(cfg) if not args.randomize: # fix the random seeds (numpy and caffe) for reproducibility np.random.seed(cfg.RNG_SEED) imdb = get_imdb(args.imdb_name) print 'Loaded dataset `{:s}` for training'.format(imdb.name) roidb = get_training_roidb(imdb) output_dir = get_output_dir(imdb, None) print 'Output will be saved to `{:s}`'.format(output_dir) # 设置cpu或者gpu的id os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id) device_name = '/{}:{:d}'.format(args.device,args.device_id) print device_name network = get_network(args.network_name) print 'Use network `{:s}` in training'.format(args.network_name) #pdb.set_trace() #在此处设置一个断点 train_net(network, imdb, roidb, output_dir, pretrained_model=args.pretrained_model, max_iters=args.max_iters)
在终端,启动网络训练。在路径Faster-RCNN_TF下,输入:
python ./tools/train_net.py --device gpu --device_id 3 --solver VGG_CNN_M_1024 --weight ./data/pretrain_model/VGG_imagenet.npy --imdb ID_card_train --network IDcard_train
参数解释:
train_net.py: 是网络的训练文件
—device :代表选用cpu还是gpu
—device_id: 代表机器上的cpu或者gpu的编号
—solver: 模型的配置文件,这个参数就不要进行修改了,固定就是VGG_CNN_M_1024
—weight: 初始化的权重文件,这里用的是Imagenet上预训练好的模型VGG_imagenet.npy,存放在目录./data/pretrain_model下
—imdb: 训练的数据库名字,这个名字可以自己随便起
—network: 代表选择训练网络还是测试网络,这个参数的形式是固定的,必须是IDcard_train的形式,前半部分IDcard可以随便起(但是不能有下划线),后半部分必须是_train
训练完成之后的模型默认保存在了目录./output/default/ID_card_train下。我们会发现,该目录下会出现以下文件:
TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。模型会保存在后缀为.ckpt的文件中。保存后,在目录./output/default/ID_card_train下会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。参考自TensorFlow学习笔记(8)–网络模型的保存和读取
checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer. model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构 TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。 model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。
3.6、测试Faster RCNN网络训练的模型
参考./tools/demo.py,写自己的demo.py。
由于我所使用的服务器中无法使用plot,所以我将检测的坐标结果直接画在了测试图片上,并且将图片保存在了目录./results下。
修改后的demo.py内容如下:
import _init_paths import tensorflow as tf from fast_rcnn.config import cfg from fast_rcnn.test import im_detect from fast_rcnn.nms_wrapper import nms from utils.timer import Timer import matplotlib.pyplot as plt import numpy as np import os, sys, cv2 import argparse from networks.factory import get_network import glob import os #os.environ['CUDA_VISIBLE_DEVICES']='3' import pdb #设断点时,使用的 CLASSES = ('__background__', 'text')物体类别 def vis_detections(im, class_name, dets, image_name, thresh=0.5): """Draw detected bounding boxes.""" inds = np.where(dets[:, -1] >= thresh)[0] if len(inds) == 0: return im = im.copy() for i in inds: bbox = dets[i, :4] #检测图片的坐标 score = dets[i, -1] cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 3) #将坐标直接画在了检测图片上 cv2.imwrite('./results/' + image_name, im) #将带有框的检测图片保存在目录./results下 print class_name def demo(sess, net, image_name_path): """Detect object classes in an image using pre-computed object proposals.""" # load images im = cv2.imread(image_name_path) im_name = os.path.basename(image_name_path) # Detect all object classes and regress object bounds timer = Timer() timer.tic() scores, boxes = im_detect(sess, net, im) print boxes timer.toc() print ('Detection took {:.3f}s for ' '{:d} object proposals').format(timer.total_time, boxes.shape[0]) # Visualize detections for each class #im = im[:, :, (2, 1, 0)] #fig, ax = plt.subplots(figsize=(12, 12)) # ax.imshow(im, aspect='equal') CONF_THRESH = 0.5 NMS_THRESH = 0.3 for cls_ind, cls in enumerate(CLASSES[1:]): cls_ind += 1 # because we skipped background cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] cls_scores = scores[:, cls_ind] dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32) keep = nms(dets, NMS_THRESH) dets = dets[keep, :] #vis_detections(im, cls, dets, ax, thresh=CONF_THRESH) vis_detections(im, cls, dets, im_name, thresh=CONF_THRESH) def parse_args(): """Parse input arguments.""" parser = argparse.ArgumentParser(description='Faster R-CNN demo') parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int) parser.add_argument('--cpu', dest='cpu_mode', help='Use CPU mode (overrides --gpu)', action='store_true') parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]', default='VGGnet_test') parser.add_argument('--model', dest='model', help='Model path', default='/data/home/lisiqi/Faster-RCNN_TF_original/weight/VGGnet_fast_rcnn_iter_70000.ckpt') parser.add_argument('--results_dir', dest='results_dir', help='Results director', default=' ') args = parser.parse_args() return args if __name__ == '__main__': cfg.TEST.HAS_RPN = True # Use RPN for proposals args = parse_args() # GPU id(设置GPU的编号) os.environ['CUDA_VISIBLE_DEVICES']=str(args.gpu_id) if args.model == ' ': raise IOError(('Error: Model not found.\n')) if args.results_dir == ' ': raise IOError(('Error: Result director not found.\n')) # init session sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) # load network net = get_network(args.demo_net) # load model #saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) #pdb.set_trace() saver = tf.train.Saver() saver.restore(sess, args.model) #加载训练好的模型,名称就写到.ckpt就行,例如VGGnet_faster_rcnn_iter_1000.ckpt #sess.run(tf.initialize_all_variables()) print '\n\nLoaded network {:s}'.format(args.model) # Warmup on a dummy image im = 128 * np.ones((300, 300, 3), dtype=np.uint8) for i in xrange(2): _, _= im_detect(sess, net, im) # load images im_file_dir = cfg.DATA_DIR + '/demo/' im_names_path = glob.glob(im_file_dir + '*.jpg') #pdb.set_trace() for im_name_path in im_names_path: #im_name = os.path.basename(im_name_path) print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~' print 'Demo for data/demo/{}'.format(im_name_path) #pdb.set_trace() demo(sess, net, im_name_path) print 'results_dir:{}'.format(args.results_dir) #plt.show()
在终端,路径Faster-RCNN_TF下,输入:
python ./tools/demo.py --gpu 3 --model ./output/default/ID_card_train/VGGnet_faster_rcnn_iter_1000.ckpt --results ./results/
参数解释:
demo.py: 测试图片的文件
—gpu : 代表机器上gpu的编号(直接就默认使用gpu,没有cpu选项)
—model: 网络训练好的模型
—results: 保存结果的路径
相关文章推荐
- 用自己的数据训练Faster-RCNN,tensorflow版本(一)
- tensorflow版本 Faster RCNN训练自己的数据集
- Tensorflow框架下Faster-RCNN实践(二)——用自己制作的数据训练Faster-RCNN网络(附代码)
- faster-rcnn安装,训练自己的数据
- faster rcnn训练自己地数据时遇到地问题
- Tensorflow框架下Faster-RCNN实践(二)——用自己制作的数据训练Faster-RCNN网络(附代码)
- Ubuntu16.04 caffe py-faster-rcnn安装以及训练自己的数据
- Caffe学习系列——Faster-RCNN训练自己的数据集
- Ubuntu16.04+cuda8.0+cudnn5.1配置faster-rcnn的方法以及训练自己的数据出现的问题
- 使用py-faster-rcnn训练自己的数据
- 【faster-rcnn】训练自己的数据——修改图片格式、类别
- 使用自己的数据训练Faster-RCNN
- 用自己的数据训练Faster-RCNN
- faster-rcnn训练和测试自己的数据(VGG/ResNet)以及遇到的问题
- faster-rcnn 中训练自己的数据出现的错误
- win7 faster_rcnn 训练自己的数据 matlab
- faster rcnn训练自己地数据时遇到地问题
- faster rcnn训练自己地数据时遇到地问题
- faster rcnn训练自己地数据时遇到地问题