
TensorFlow Object Detection API Tutorial: Training, Prediction, and Testing with Your Own Dataset

2018-01-04 15:43

Thoughts

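If you are not sure how to build the dataset, see my previous post: http://blog.csdn.net/w5688414/article/details/78970874. This post covers how to train on your own dataset. I ran the model with Python 3 on Ubuntu 16.04.
I recommend getting the official demo running first and becoming familiar with the workflow before following the steps below; otherwise it is easy to go astray.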

Training

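First, clone the tensorflow/models repository with git:
git clone https://github.com/tensorflow/models.git
The repository is fairly large, so download it over a fast connection if you can. Then add the environment variables. My settings are below (the first line uses `pwd` and must be run from models/research/; the second uses absolute paths):
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
export PYTHONPATH="${PYTHONPATH}:/home/whsyxt/Downloads/gaoshengwu/models/research:/home/whsyxt/Downloads/gaoshengwu/models/research/slim/"
If you add the absolute-path line to ~/.bashrc, it takes effect at every login. You can also skip the permanent setting; it is just less convenient. The method from the official guide works as well: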

# From tensorflow/models/research/
protoc object_detection/protos/*.proto --python_out=.
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
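A quick way to confirm that PYTHONPATH is set correctly is to check that Python can locate the packages. The sketch below is my own addition, not part of the official instructions; it assumes `object_detection` (from models/research) and `nets` (from models/research/slim) are the package names that should become importable.

```python
import importlib.util
from typing import List


def missing_modules(names: List[str]) -> List[str]:
    """Return the subset of top-level module names that cannot be found on sys.path."""
    missing = []
    for name in names:
        # find_spec returns None when no module with this name can be located.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # object_detection lives in models/research, nets in models/research/slim.
    problems = missing_modules(["object_detection", "nets"])
    if problems:
        print("Not importable, check PYTHONPATH:", ", ".join(problems))
    else:
        print("object_detection and slim are on the path")
```

Run it from a fresh shell: if the variables only live in the current terminal session, a new login will report both modules as missing.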

Then, from the research directory, run:
sudo python3 setup.py install

The official installation guide is here: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md
For the remaining installation steps, simply follow that guide.

Next, put the generated TFRecord files in ./models/research/object_detection/data.
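Before training, it is worth checking that the TFRecord files are not empty or truncated. This is a minimal sketch of my own that counts records by walking the TFRecord framing (a little-endian uint64 length, a 4-byte length CRC, the payload, and a 4-byte payload CRC); it does not verify the CRCs and does not need TensorFlow.

```python
import struct


def count_tfrecords(path: str) -> int:
    """Count records in a TFRecord file by walking its length headers.

    Layout per record: uint64 length (little-endian), uint32 length CRC,
    `length` bytes of payload, uint32 payload CRC. CRCs are NOT checked.
    """
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise ValueError("truncated record header")
            (length,) = struct.unpack("<Q", header)
            # Skip the length CRC, the payload and the payload CRC.
            body = f.read(4 + length + 4)
            if len(body) < 4 + length + 4:
                raise ValueError("truncated record body")
            count += 1
    return count
```

For example, `count_tfrecords("data/whsyxt_validation.tfrecord")` gives the number of validation images, which is a sensible value for num_examples in the eval_config later on.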
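In ./models/research/object_detection, create a training folder, and inside it a folder named ssd_inception_v2_whsyxt. Then create the label map file there. Mine is whsyxt_label_map.pbtxt; its ids must match the class ids written into the TFRecord files and start from 1 (0 is reserved for the background). Its content is: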
item {
  id: 2
  name: 'person'
}

item {
  id: 1
  name: 'car'
}
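With more classes, writing the label map by hand becomes error-prone. The following sketch uses a hypothetical helper, make_label_map, of my own: it generates the same pbtxt format from a Python dict and rejects ids below 1. The ids still have to match whatever class ids your TFRecord-generation script wrote.

```python
from typing import Dict


def make_label_map(names_by_id: Dict[int, str]) -> str:
    """Render a {class_id: name} dict as label map text (pbtxt).

    Ids must be >= 1 because id 0 is reserved for the background class.
    """
    for class_id in names_by_id:
        if class_id < 1:
            raise ValueError("label map ids must start at 1, got %d" % class_id)
    items = []
    for class_id in sorted(names_by_id):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, names_by_id[class_id]))
    # A blank line between items, as in the hand-written file above.
    return "\n".join(items)


if __name__ == "__main__":
    # Same classes as whsyxt_label_map.pbtxt above.
    print(make_label_map({1: "car", 2: "person"}))
```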
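Then copy ssd_inception_v2_coco.config (the sample config in object_detection/samples/configs) into the ssd_inception_v2_whsyxt folder and edit it. First, set the number of classes:
num_classes: 2
Then change the paths: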
train_input_reader: {
  tf_record_input_reader {
    input_path: "data/whsyxt_train.tfrecord"
  }
  label_map_path: "training/ssd_inception_v2_whsyxt/whsyxt_label_map.pbtxt"
}
eval_input_reader: {
  tf_record_input_reader {
    input_path: "data/whsyxt_validation.tfrecord"
  }
  label_map_path: "training/ssd_inception_v2_whsyxt/whsyxt_label_map.pbtxt"
  shuffle: false
  num_readers: 1
  num_epochs: 1
}


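Just follow my example: essentially you only change the number of classes and the paths. As the comment at the top of the config notes, fine_tune_checkpoint must also point to a pretrained checkpoint; mine is the COCO-pretrained ssd_inception_v2_coco_2017_11_17 model from the detection model zoo, extracted into the object_detection directory. My complete config file is as follows: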

# SSD with Inception v2 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

model {
  ssd {
    num_classes: 2
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    anchor_generator {
      ssd_anchor_generator {
        num_layers: 6
        min_scale: 0.2
        max_scale: 0.95
        aspect_ratios: 1.0
        aspect_ratios: 2.0
        aspect_ratios: 0.5
        aspect_ratios: 3.0
        aspect_ratios: 0.3333
        reduce_boxes_in_lowest_layer: true
      }
    }
    image_resizer {
      fixed_shape_resizer {
        height: 300
        width: 300
      }
    }
    box_predictor {
      convolutional_box_predictor {
        min_depth: 0
        max_depth: 0
        num_layers_before_predictor: 0
        use_dropout: false
        dropout_keep_probability: 0.8
        kernel_size: 3
        box_code_size: 4
        apply_sigmoid_to_scores: false
        conv_hyperparams {
          activation: RELU_6,
          regularizer {
            l2_regularizer {
              weight: 0.00004
            }
          }
          initializer {
            truncated_normal_initializer {
              stddev: 0.03
              mean: 0.0
            }
          }
        }
      }
    }
    feature_extractor {
      type: 'ssd_inception_v2'
      min_depth: 16
      depth_multiplier: 1.0
      conv_hyperparams {
        activation: RELU_6,
        regularizer {
          l2_regularizer {
            weight: 0.00004
          }
        }
        initializer {
          truncated_normal_initializer {
            stddev: 0.03
            mean: 0.0
          }
        }
        batch_norm {
          train: true,
          scale: true,
          center: true,
          decay: 0.9997,
          epsilon: 0.001,
        }
      }
    }
    loss {
      classification_loss {
        weighted_sigmoid {
          anchorwise_output: true
        }
      }
      localization_loss {
        weighted_smooth_l1 {
          anchorwise_output: true
        }
      }
      hard_example_miner {
        num_hard_examples: 3000
        iou_threshold: 0.99
        loss_type: CLASSIFICATION
        max_negatives_per_positive: 3
        min_negatives_per_image: 0
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    normalize_loss_by_num_matches: true
    post_processing {
      batch_non_max_suppression {
        score_threshold: 1e-8
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 100
      }
      score_converter: SIGMOID
    }
  }
}

train_config: {
  batch_size: 24
  optimizer {
    rms_prop_optimizer: {
      learning_rate: {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.004
          decay_steps: 800720
          decay_factor: 0.95
        }
      }
      momentum_optimizer_value: 0.9
      decay: 0.9
      epsilon: 1.0
    }
  }
  fine_tune_checkpoint: "ssd_inception_v2_coco_2017_11_17/model.ckpt"
  from_detection_checkpoint: true
  # Note: The below line limits the training process to 200K steps, which we
  # empirically found to be sufficient enough to train the pets dataset. This
  # effectively bypasses the learning rate schedule (the learning rate will
  # never decay). Remove the below line to train indefinitely.
  num_steps: 200000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "data/whsyxt_train.tfrecord"
  }
  label_map_path: "training/ssd_inception_v2_whsyxt/whsyxt_label_map.pbtxt"
}

eval_config: {
  num_examples: 8000
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "data/whsyxt_validation.tfrecord"
  }
  label_map_path: "training/ssd_inception_v2_whsyxt/whsyxt_label_map.pbtxt"
  shuffle: false
  num_readers: 1
  num_epochs: 1
}
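Two quantities in this config are worth working out by hand; the sketch below is my own illustration, not code from the API. The first function assumes the exponential decay runs in staircase mode (I believe `staircase` defaults to true for exponential_decay_learning_rate). That explains the config comment: 200000 steps never reach decay_steps = 800720, so the rate stays at 0.004. Without staircase it would only have fallen to about 0.00395 by step 200000. The second function spaces the per-layer anchor scales linearly between min_scale and max_scale, as in the SSD paper; the extra boxes the generator adds per layer and the reduced set from reduce_boxes_in_lowest_layer are not modeled.

```python
import math
from typing import List


def exponential_decay_lr(step: int, initial_lr: float = 0.004,
                         decay_steps: int = 800720, decay_factor: float = 0.95,
                         staircase: bool = True) -> float:
    """initial_lr * decay_factor ** (step / decay_steps), floored if staircase."""
    exponent = step / decay_steps
    if staircase:
        # Staircase decay only lowers the rate once every decay_steps steps.
        exponent = math.floor(exponent)
    return initial_lr * decay_factor ** exponent


def ssd_layer_scales(num_layers: int = 6, min_scale: float = 0.2,
                     max_scale: float = 0.95) -> List[float]:
    """Anchor scale for each feature map, spaced linearly from min_scale to max_scale."""
    return [min_scale + (max_scale - min_scale) * i / (num_layers - 1)
            for i in range(num_layers)]


if __name__ == "__main__":
    # num_steps (200000) < decay_steps (800720): the rate never drops.
    print(exponential_decay_lr(199999))  # 0.004
    print([round(s, 2) for s in ssd_layer_scales()])  # [0.2, 0.35, 0.5, 0.65, 0.8, 0.95]
```

If you shorten num_steps for a small dataset, you may also want to shrink decay_steps so that the schedule actually takes effect.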
Then go back to the object_detection directory. My training command is:
python3 train.py \
--logtostderr  \
--train_dir=training/ssd_inception_v2_whsyxt \
--pipeline_config_path=training/ssd_inception_v2_whsyxt/ssd_inception_v2_coco.config
When training is done, the training/ssd_inception_v2_whsyxt directory will contain many checkpoint files named model.ckpt-<step>.*.
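The export step below needs the prefix of one checkpoint. This hypothetical helper of my own picks the checkpoint with the highest step by scanning for model.ckpt-<step>.index files; inside a TensorFlow 1.x session, tf.train.latest_checkpoint(train_dir) does the same job by reading the `checkpoint` file.

```python
import os
import re
from typing import Optional

# TF1 checkpoints are saved as model.ckpt-<step>.index / .meta / .data-*.
_CKPT_RE = re.compile(r"^model\.ckpt-(\d+)\.index$")


def latest_checkpoint_prefix(train_dir: str) -> Optional[str]:
    """Return e.g. '<train_dir>/model.ckpt-146589' for the highest step, or None."""
    best_step = None
    for name in os.listdir(train_dir):
        m = _CKPT_RE.match(name)
        if m and (best_step is None or int(m.group(1)) > best_step):
            best_step = int(m.group(1))
    if best_step is None:
        return None
    return os.path.join(train_dir, "model.ckpt-%d" % best_step)
```

The returned string is what goes into --trained_checkpoint_prefix.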
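To predict with the trained model, first export it. I created an inference_graph/ssd_whsyxt_inference_graph directory under object_detection to hold the exported files, and ran the command below (146589 is the step number of my checkpoint; use one that exists in your own training directory):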
python3 export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path training/ssd_inception_v2_whsyxt/ssd_inception_v2_coco.config \
--trained_checkpoint_prefix training/ssd_inception_v2_whsyxt/model.ckpt-146589 \
--output_directory inference_graph/ssd_whsyxt_inference_graph
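The exported files can then be used for prediction. I wrote a prediction script modeled on the official demo; it runs on a video file or opens the webcam. The file is object_detection_tutorial.py, shown here for reference (run it from the object_detection directory, since it uses relative paths):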
# coding: utf-8

# # Object Detection Demo
# Welcome to the object detection inference walkthrough!  This script runs an exported (frozen) detection model on frames from a webcam or a video file. Make sure to follow the [installation instructions](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md) before you start.

# # Imports

import os
import sys
import time

import cv2
import numpy as np
import tensorflow as tf  # this script uses the TensorFlow 1.x API

cap = cv2.VideoCapture(0)  # open the default webcam
# To run on a video file instead, pass its path:
# cap = cv2.VideoCapture("car.mp4")
# cap = cv2.VideoCapture("DJI_0004.MOV")
#if tf.__version__ != '1.4.0':
#  raise ImportError('Please upgrade your tensorflow installation to v1.4.0!')

# ## Env setup

# This is needed to display the images.
# get_ipython().magic(u'matplotlib inline')

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")

# ## Object detection imports
# Here are the imports from the object detection module.

from utils import label_map_util

from utils import visualization_utils as vis_util

# # Model preparation

# ## Variables
#
# Any model exported using the `export_inference_graph.py` tool can be loaded here simply by changing `PATH_TO_CKPT` to point to a new .pb file.
#
# Here we load the SSD Inception v2 model trained and exported above. See the [detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) for a list of other models that can be run out-of-the-box with varying speeds and accuracies.

# What model to download.
# MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
# MODEL_FILE = MODEL_NAME + '.tar.gz'
MODEL_NAME = 'inference_graph/ssd_whsyxt_inference_graph'
# DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# Label map used to attach the correct class name to each box.
# PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
PATH_TO_LABELS = os.path.join('training/ssd_inception_v2_whsyxt', 'whsyxt_label_map.pbtxt')

NUM_CLASSES = 2

# ## Download Model

# opener = urllib.request.URLopener()
# opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
# tar_file = tarfile.open(MODEL_FILE)
# for file in tar_file.getmembers():
#   file_name = os.path.basename(file.name)
#   if 'frozen_inference_graph.pb' in file_name:
#     tar_file.extract(file, os.getcwd())

# ## Load a (frozen) Tensorflow model into memory.

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# ## Loading label map
# Label maps map indices to category names, so that when our convolution network predicts `5`, we know that this corresponds to `airplane`.  Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# ## Helper code

def load_image_into_numpy_array(image):
    # Convert a PIL image into a uint8 array of shape (height, width, 3).
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)

# # Detection

# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
'''
PATH_TO_TEST_IMAGES_DIR = 'test_images'
# PATH_TO_TEST_IMAGES_DIR = 'demo_2017117'
images = os.listdir(PATH_TO_TEST_IMAGES_DIR)
#TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
TEST_IMAGE_PATHS = []
for image_name in images:
    if str(image_name.split(".")[-1]) == "jpg":
        TEST_IMAGE_PATHS.append(os.path.join(PATH_TO_TEST_IMAGES_DIR, image_name))
# TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR,images) for i in range(1, 3) ]
'''
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        # Define input and output tensors for detection_graph (looked up once, outside the loop).
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represents the level of confidence for the corresponding object.
        # The score is shown on the result image, together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of the video file, or the camera could not be read.
                break
            # OpenCV delivers BGR frames; the model was trained on RGB images.
            image_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
            image_np_expanded = np.expand_dims(image_np, axis=0)

            start = time.time()
            # Actual detection.
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            end = time.time()
            # Time spent on inference for this frame
            seconds = end - start
            print("Time taken : {0} seconds".format(seconds))
            # Estimate frames per second from the inference time
            fps = 1 / seconds
            print("Estimated frames per second : {0}".format(fps))
            # Visualization of the results of a detection (drawn in place on image_np).
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            # Convert back to BGR for display with OpenCV.
            cv2.imshow('object detection',
                       cv2.resize(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR), (800, 600)))
            if cv2.waitKey(25) & 0xFF == ord('q'):
                break

cap.release()
cv2.destroyAllWindows()
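The script only draws the boxes. If you need the detections themselves (for example, to count cars and people per frame), this sketch of my own keeps the detections above a score threshold and converts the normalized [ymin, xmin, ymax, xmax] boxes to pixel coordinates. The 0.5 default is an assumption that I believe matches the default min_score_thresh of visualize_boxes_and_labels_on_image_array; the low-scoring tail of the fixed-size output arrays is dropped by the same filter.

```python
import numpy as np


def detections_above_threshold(boxes, scores, classes, image_height, image_width,
                               min_score=0.5):
    """Keep detections with score >= min_score and convert boxes to pixels.

    boxes: [N, 4] array of normalized [ymin, xmin, ymax, xmax] rows;
    scores, classes: [N] arrays (the squeezed outputs of sess.run()).
    Returns a list of (class_id, score, (left, top, right, bottom)) tuples.
    """
    boxes = np.asarray(boxes, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    classes = np.asarray(classes).astype(np.int32)
    keep = scores >= min_score
    results = []
    for (ymin, xmin, ymax, xmax), score, cls in zip(boxes[keep], scores[keep], classes[keep]):
        # Round to the nearest pixel rather than truncating.
        left, right = int(round(xmin * image_width)), int(round(xmax * image_width))
        top, bottom = int(round(ymin * image_height)), int(round(ymax * image_height))
        results.append((int(cls), float(score), (left, top, right, bottom)))
    return results
```

Inside the loop it would be called as `detections_above_threshold(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), *image_np.shape[:2])`, and the class ids can be looked up in category_index to get the names.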
